From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/events_worker.py | 1454 +++++++++++++++++++++++ 1 file changed, 1454 insertions(+) create mode 100644 synapse/storage/databases/main/events_worker.py (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py new file mode 100644 index 0000000000..a7b7393f6e --- /dev/null +++ b/synapse/storage/databases/main/events_worker.py @@ -0,0 +1,1454 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import itertools +import logging +import threading +from collections import namedtuple +from typing import List, Optional, Tuple + +from constantly import NamedConstant, Names + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import NotFoundError, SynapseError +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) +from synapse.events import make_event_from_dict +from synapse.events.utils import prune_event +from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import BackfillStream +from synapse.replication.tcp.streams.events import EventsStream +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import DatabasePool +from synapse.storage.types import Cursor +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import get_domain_from_id +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) +from synapse.util.iterutils import batch_iter +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + +class EventsWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) + + if hs.config.worker.writers.events == hs.get_instance_name(): + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + + self._get_event_cache = Cache( + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, + ) + + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(-token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self.db_pool.simple_select_one_onecol( + table="events", + keyvalues={"event_id": event_id}, + retcol="received_ts", + desc="get_received_ts", + ) + + def get_received_ts_by_stream_pos(self, stream_ordering): + """Given a stream ordering get an approximate timestamp of when it + happened. + + This is done by simply taking the received ts of the first event that + has a stream ordering greater than or equal to the given stream pos. + If none exists returns the current time, on the assumption that it must + have happened recently. + + Args: + stream_ordering (int) + + Returns: + Deferred[int] + """ + + def _get_approximate_received_ts_txn(txn): + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering >= ? + LIMIT 1 + """ + + txn.execute(sql, (stream_ordering,)) + row = txn.fetchone() + if row and row[0]: + ts = row[0] + else: + ts = self.clock.time_msec() + + return ts + + return self.db_pool.runInteraction( + "get_approximate_received_ts", _get_approximate_received_ts_txn + ) + + @defer.inlineCallbacks + def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, + ): + """Get an event from the database by event_id. + + Args: + event_id: The event_id of the event to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (behave as per allow_none + if the event is redacted) + + get_prev_content: If True and event is a state event, + include the previous states content in the unsigned field. + + allow_rejected: If True, return rejected events. Otherwise, + behave as per allow_none. + + allow_none: If True, return None if no event found, if + False throw a NotFoundError + + check_room_id: if not None, check the room of the found event. + If there is a mismatch, behave as per allow_none. + + Returns: + Deferred[EventBase|None] + """ + if not isinstance(event_id, str): + raise TypeError("Invalid event event_id %r" % (event_id,)) + + events = yield self.get_events_as_list( + [event_id], + redact_behaviour=redact_behaviour, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event = events[0] if events else None + + if event is not None and check_room_id is not None: + if event.room_id != check_room_id: + event = None + + if event is None and not allow_none: + raise NotFoundError("Could not find event %s" % (event_id,)) + + return event + + @defer.inlineCallbacks + def get_events( + self, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + ): + """Get events from the database + + Args: + event_ids: The event_ids of the events to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (omit them from the response) + + get_prev_content: If True and event is a state event, + include the previous states content in the unsigned field. + + allow_rejected: If True, return rejected events. Otherwise, + omits rejeted events from the response. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self.get_events_as_list( + event_ids, + redact_behaviour=redact_behaviour, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return {e.event_id: e for e in events} + + @defer.inlineCallbacks + def get_events_as_list( + self, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + ): + """Get events from the database and return in a list in the same order + as given by `event_ids` arg. + + Unknown events will be omitted from the response. + + Args: + event_ids: The event_ids of the events to fetch + + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events (omit them from the response) + + get_prev_content: If True and event is a state event, + include the previous states content in the unsigned field. + + allow_rejected: If True, return rejected events. Otherwise, + omits rejected events from the response. + + Returns: + Deferred[list[EventBase]]: List of events fetched from the database. The + events are in the same order as `event_ids` arg. + + Note that the returned list may be smaller than the list of event + IDs if not all events could be fetched. + """ + + if not event_ids: + return [] + + # there may be duplicates so we cast the list to a set + event_entry_map = yield self._get_events_from_cache_or_db( + set(event_ids), allow_rejected=allow_rejected + ) + + events = [] + for event_id in event_ids: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if not allow_rejected: + assert not entry.event.rejected_reason, ( + "rejected event returned from _get_events_from_cache_or_db despite " + "allow_rejected=False" + ) + + # We may not have had the original event when we received a redaction, so + # we have to recheck auth now. + + if not allow_rejected and entry.event.type == EventTypes.Redaction: + if entry.event.redacts is None: + # A redacted redaction doesn't have a `redacts` key, in + # which case lets just withhold the event. + # + # Note: Most of the time if the redactions has been + # redacted we still have the un-redacted event in the DB + # and so we'll still see the `redacts` key. However, this + # isn't always true e.g. if we have censored the event. + logger.debug( + "Withholding redaction event %s as we don't have redacts key", + event_id, + ) + continue + + redacted_event_id = entry.event.redacts + event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + original_event_entry = event_map.get(redacted_event_id) + if not original_event_entry: + # we don't have the redacted event (or it was rejected). + # + # We assume that the redaction isn't authorized for now; if the + # redacted event later turns up, the redaction will be re-checked, + # and if it is found valid, the original will get redacted before it + # is served to the client. + logger.debug( + "Withholding redaction event %s since we don't (yet) have the " + "original %s", + event_id, + redacted_event_id, + ) + continue + + original_event = original_event_entry.event + if original_event.type == EventTypes.Create: + # we never serve redactions of Creates to clients. + logger.info( + "Withholding redaction %s of create event %s", + event_id, + redacted_event_id, + ) + continue + + if original_event.room_id != entry.event.room_id: + logger.info( + "Withholding redaction %s of event %s from a different room", + event_id, + redacted_event_id, + ) + continue + + if entry.event.internal_metadata.need_to_check_redaction(): + original_domain = get_domain_from_id(original_event.sender) + redaction_domain = get_domain_from_id(entry.event.sender) + if original_domain != redaction_domain: + # the senders don't match, so this is forbidden + logger.info( + "Withholding redaction %s whose sender domain %s doesn't " + "match that of redacted event %s %s", + event_id, + redaction_domain, + redacted_event_id, + original_domain, + ) + continue + + # Update the cache to save doing the checks again. + entry.event.internal_metadata.recheck_redaction = False + + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + return events + + @defer.inlineCallbacks + def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the cache or the database. + + If events are pulled from the database, they will be cached for future lookups. + + Unknown events are omitted from the response. + + Args: + + event_ids (Iterable[str]): The event_ids of the events to fetch + + allow_rejected (bool): Whether to include rejected events. If False, + rejected events are omitted from the response. + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result + """ + event_entry_map = self._get_events_from_cache( + event_ids, allow_rejected=allow_rejected + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + log_ctx = current_context() + log_ctx.record_event_fetch(len(missing_events_ids)) + + # Note that _get_events_from_db is also responsible for turning db rows + # into FrozenEvents (via _get_event_from_row), which involves seeing if + # the events have been redacted, and if so pulling the redaction event out + # of the database to check it. + # + missing_events = yield self._get_events_from_db( + missing_events_ids, allow_rejected=allow_rejected + ) + + event_entry_map.update(missing_events) + + return event_entry_map + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (Iterable[str]): list of event_ids to fetch + allow_rejected (bool): Whether to return events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, update_metrics=update_metrics + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + self._fetch_event_list(conn, event_list) + + def _fetch_event_list(self, conn, event_list): + """Handle a load of requests from the _event_fetch_list queue + + Args: + conn (twisted.enterprise.adbapi.Connection): database connection + + event_list (list[Tuple[list[str], Deferred]]): + The fetch requests. Each entry consists of a list of event + ids to be fetched, and a deferred to be completed once the + events have been fetched. + + The deferreds are callbacked with a dictionary mapping from event id + to event row. Note that it may well contain additional events that + were not part of this request. + """ + with Measure(self._clock, "_fetch_event_list"): + try: + events_to_fetch = { + event_id for events, _ in event_list for event_id in events + } + + row_dict = self.db_pool.new_transaction( + conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch + ) + + # We only want to resolve deferreds from the main thread + def fire(): + for _, d in event_list: + d.callback(row_dict) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs, exc): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(exc) + + with PreserveLoggingContext(): + self.hs.get_reactor().callFromThread(fire, event_list, e) + + @defer.inlineCallbacks + def _get_events_from_db(self, event_ids, allow_rejected=False): + """Fetch a bunch of events from the database. + + Returned events will be added to the cache for future lookups. + + Unknown events are omitted from the response. + + Args: + event_ids (Iterable[str]): The event_ids of the events to fetch + + allow_rejected (bool): Whether to include rejected events. If False, + rejected events are omitted from the response. + + Returns: + Deferred[Dict[str, _EventCacheEntry]]: + map from event id to result. May return extra events which + weren't asked for. + """ + fetched_events = {} + events_to_fetch = event_ids + + while events_to_fetch: + row_map = yield self._enqueue_events(events_to_fetch) + + # we need to recursively fetch any redactions of those events + redaction_ids = set() + for event_id in events_to_fetch: + row = row_map.get(event_id) + fetched_events[event_id] = row + if row: + redaction_ids.update(row["redactions"]) + + events_to_fetch = redaction_ids.difference(fetched_events.keys()) + if events_to_fetch: + logger.debug("Also fetching redaction events %s", events_to_fetch) + + # build a map from event_id to EventBase + event_map = {} + for event_id, row in fetched_events.items(): + if not row: + continue + assert row["event_id"] == event_id + + rejected_reason = row["rejected_reason"] + + if not allow_rejected and rejected_reason: + continue + + d = db_to_json(row["json"]) + internal_metadata = db_to_json(row["internal_metadata"]) + + format_version = row["format_version"] + if format_version is None: + # This means that we stored the event before we had the concept + # of a event format version, so it must be a V1 event. + format_version = EventFormatVersions.V1 + + room_version_id = row["room_version_id"] + + if not room_version_id: + # this should only happen for out-of-band membership events + if not internal_metadata.get("out_of_band_membership"): + logger.warning( + "Room %s for event %s is unknown", d["room_id"], event_id + ) + continue + + # take a wild stab at the room version based on the event format + if format_version == EventFormatVersions.V1: + room_version = RoomVersions.V1 + elif format_version == EventFormatVersions.V2: + room_version = RoomVersions.V3 + else: + room_version = RoomVersions.V5 + else: + room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not room_version: + logger.warning( + "Event %s in room %s has unknown room version %s", + event_id, + d["room_id"], + room_version_id, + ) + continue + + if room_version.event_format != format_version: + logger.error( + "Event %s in room %s with version %s has wrong format: " + "expected %s, was %s", + event_id, + d["room_id"], + room_version_id, + room_version.event_format, + format_version, + ) + continue + + original_ev = make_event_from_dict( + event_dict=d, + room_version=room_version, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + event_map[event_id] = original_ev + + # finally, we can decide whether each one needs redacting, and build + # the cache entries. + result_map = {} + for event_id, original_ev in event_map.items(): + redactions = fetched_events[event_id]["redactions"] + redacted_event = self._maybe_redact_event_row( + original_ev, redactions, event_map + ) + + cache_entry = _EventCacheEntry( + event=original_ev, redacted_event=redacted_event + ) + + self._get_event_cache.prefill((event_id,), cache_entry) + result_map[event_id] = cache_entry + + return result_map + + @defer.inlineCallbacks + def _enqueue_events(self, events): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + + Args: + events (Iterable[str]): events to be fetched. + + Returns: + Deferred[Dict[str, Dict]]: map from event id to row data from the database. + May contain events that weren't requested. + """ + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append((events, events_d)) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + run_as_background_process( + "fetch_events", self.db_pool.runWithConnection, self._do_fetch + ) + + logger.debug("Loading %d events: %s", len(events), events) + with PreserveLoggingContext(): + row_map = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) + + return row_map + + def _fetch_event_rows(self, txn, event_ids): + """Fetch event rows from the database + + Events which are not found are omitted from the result. + + The returned per-event dicts contain the following keys: + + * event_id (str) + + * json (str): json-encoded event structure + + * internal_metadata (str): json-encoded internal metadata dict + + * format_version (int|None): The format of the event. Hopefully one + of EventFormatVersions. 'None' means the event predates + EventFormatVersions (so the event is format V1). + + * room_version_id (str|None): The version of the room which contains the event. + Hopefully one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + + * rejected_reason (str|None): if the event was rejected, the reason + why. + + * redactions (List[str]): a list of event-ids which (claim to) redact + this event. + + Args: + txn (twisted.enterprise.adbapi.Connection): + event_ids (Iterable[str]): event IDs to fetch + + Returns: + Dict[str, Dict]: a map from event id to event info. + """ + event_dict = {} + for evs in batch_iter(event_ids, 200): + sql = """\ + SELECT + e.event_id, + e.internal_metadata, + e.json, + e.format_version, + r.room_version, + rej.reason + FROM event_json as e + LEFT JOIN rooms r USING (room_id) + LEFT JOIN rejections as rej USING (event_id) + WHERE """ + + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", evs + ) + + txn.execute(sql + clause, args) + + for row in txn: + event_id = row[0] + event_dict[event_id] = { + "event_id": event_id, + "internal_metadata": row[1], + "json": row[2], + "format_version": row[3], + "room_version_id": row[4], + "rejected_reason": row[5], + "redactions": [], + } + + # check for redactions + redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " + + clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) + + txn.execute(redactions_sql + clause, args) + + for (redacter, redacted) in txn: + d = event_dict.get(redacted) + if d: + d["redactions"].append(redacter) + + return event_dict + + def _maybe_redact_event_row(self, original_ev, redactions, event_map): + """Given an event object and a list of possible redacting event ids, + determine whether to honour any of those redactions and if so return a redacted + event. + + Args: + original_ev (EventBase): + redactions (iterable[str]): list of event ids of potential redaction events + event_map (dict[str, EventBase]): other events which have been fetched, in + which we can look up the redaaction events. Map from event id to event. + + Returns: + Deferred[EventBase|None]: if the event should be redacted, a pruned + event object. Otherwise, None. + """ + if original_ev.type == "m.room.create": + # we choose to ignore redactions of m.room.create events. + return None + + for redaction_id in redactions: + redaction_event = event_map.get(redaction_id) + if not redaction_event or redaction_event.rejected_reason: + # we don't have the redaction event, or the redaction event was not + # authorized. + logger.debug( + "%s was redacted by %s but redaction not found/authed", + original_ev.event_id, + redaction_id, + ) + continue + + if redaction_event.room_id != original_ev.room_id: + logger.debug( + "%s was redacted by %s but redaction was in a different room!", + original_ev.event_id, + redaction_id, + ) + continue + + # Starting in room version v3, some redactions need to be + # rechecked if we didn't have the redacted event at the + # time, so we recheck on read instead. + if redaction_event.internal_metadata.need_to_check_redaction(): + expected_domain = get_domain_from_id(original_ev.sender) + if get_domain_from_id(redaction_event.sender) == expected_domain: + # This redaction event is allowed. Mark as not needing a recheck. + redaction_event.internal_metadata.recheck_redaction = False + else: + # Senders don't match, so the event isn't actually redacted + logger.debug( + "%s was redacted by %s but the senders don't match", + original_ev.event_id, + redaction_id, + ) + continue + + logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id) + + # we found a good redaction event. Redact! + redacted_event = prune_event(original_ev) + redacted_event.unsigned["redacted_by"] = redaction_id + + # It's fine to add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = redaction_event + + return redacted_event + + # no valid redaction found for this event + return None + + @defer.inlineCallbacks + def have_events_in_timeline(self, event_ids): + """Given a list of event ids, check if we have already processed and + stored them as non outliers. + """ + rows = yield self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) + + return {r["event_id"] for r in rows} + + @defer.inlineCallbacks + def have_seen_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Args: + event_ids (iterable[str]): + + Returns: + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = "SELECT event_id FROM events as e WHERE " + clause, args = make_in_list_sql_clause( + txn.database_engine, "e.event_id", chunk + ) + txn.execute(sql + clause, args) + for (event_id,) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + yield self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn, chunk + ) + return results + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_total_state_event_counts. + """ + # We join against the events table as that has an index on room_id + sql = """ + SELECT COUNT(*) FROM state_events + INNER JOIN events USING (room_id, event_id) + WHERE room_id=? + """ + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.db_pool.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, + room_id, + ) + + def _get_current_state_event_counts_txn(self, txn, room_id): + """ + See get_current_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_current_state_event_counts(self, room_id): + """ + Gets the current number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.db_pool.runInteraction( + "get_current_state_event_counts", + self._get_current_state_event_counts_txn, + room_id, + ) + + @defer.inlineCallbacks + def get_room_complexity(self, room_id): + """ + Get a rough approximation of the complexity of the room. This is used by + remote servers to decide whether they wish to join the room or not. + Higher complexity value indicates that being in the room will consume + more resources. + + Args: + room_id (str) + + Returns: + Deferred[dict[str:int]] of complexity version to complexity. + """ + state_events = yield self.get_current_state_event_counts(room_id) + + # Call this one "v1", so we can introduce new ones as we want to develop + # it. + complexity_v1 = round(state_events / 500, 2) + + return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + """Returns new events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db_pool.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_ex_outlier_stream_rows(self, last_id, current_id): + """Returns de-outliered events, for the Events replication stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + + Returns: Deferred[List[Tuple]] + a list of events stream rows. Each tuple consists of a stream id as + the first element, followed by fields suitable for casting into an + EventsStreamRow. + """ + + def get_ex_outlier_stream_rows_txn(txn): + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering ASC" + ) + + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.db_pool.runInteraction( + "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn + ) + + async def get_all_new_backfill_event_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for backfill replication stream, including all new + backfilled events and events that have gone from being outliers to not. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: + return [], current_id, False + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = [(row[0], row[1:]) for row in txn] + + limited = False + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + limited = True + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend((row[0], row[1:]) for row in txn) + + if len(new_event_updates) >= limit: + upper_bound = new_event_updates[-1][0] + limited = True + + return new_event_updates, upper_bound, limited + + return await self.db_pool.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, target_row_count)) + return txn.fetchall() + + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. + + rows = await self.db_pool.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db_pool.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token + ) + + return rows, to_token, True + + @cached(num_args=5, max_entries=10) + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + if have_forward_events: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() + else: + new_forward_events = [] + forward_ex_outliers = [] + + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + if have_backfill_events: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() + else: + new_backfill_events = [] + backward_ex_outliers = [] + + return AllNewEventsResult( + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, + ) + + return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) + + async def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) + return (to_1, so_1) > (to_2, so_2) + + @cachedInlineCallbacks(max_entries=5000) + def get_event_ordering(self, event_id): + res = yield self.db_pool.simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + return (int(res["topological_ordering"]), int(res["stream_ordering"])) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db_pool.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. + + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db_pool.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. + txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + + +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) -- cgit 1.5.1 From 2ffd6783c7af12e3c29e1a44dee4a9deeb83890b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 6 Aug 2020 17:15:35 +0100 Subject: Revert #7736 (#8039) --- changelog.d/7736.feature | 1 - changelog.d/8039.misc | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 - synapse/push/push_tools.py | 17 ++- synapse/rest/client/v2_alpha/sync.py | 1 - synapse/storage/databases/main/cache.py | 1 - synapse/storage/databases/main/events.py | 48 +------ synapse/storage/databases/main/events_worker.py | 86 +---------- .../main/schema/delta/58/12unread_messages.sql | 18 --- tests/rest/client/v1/utils.py | 20 --- tests/rest/client/v2_alpha/test_sync.py | 157 +-------------------- 12 files changed, 19 insertions(+), 339 deletions(-) delete mode 100644 changelog.d/7736.feature create mode 100644 changelog.d/8039.misc delete mode 100644 synapse/storage/databases/main/schema/delta/58/12unread_messages.sql (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature deleted file mode 100644 index feb02be234..0000000000 --- a/changelog.d/7736.feature +++ /dev/null @@ -1 +0,0 @@ -Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). diff --git a/changelog.d/8039.misc b/changelog.d/8039.misc new file mode 100644 index 0000000000..599933c80e --- /dev/null +++ b/changelog.d/8039.misc @@ -0,0 +1 @@ +Revert MSC2654 implementation because of perf issues. Please delete this line when processing the 1.19 changelog. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index ae5e1810fc..a34bdf1830 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -67,7 +67,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url", "count_as_unread"], + "events": ["processed", "outlier", "contains_url"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5a19bac929..c42dac18f5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,7 +103,6 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1887,10 +1886,6 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] - - unread_count = await self.store.get_unread_message_count_for_user( - room_id, sync_config.user.to_string(), - ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1899,7 +1894,6 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, - unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index bc8f71916b..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,13 +21,22 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + badge = len(invites) for room_id in joins: - unread_count = await store.get_unread_message_count_for_user(room_id, user_id) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if unread_count else 0 + if room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[room_id] + + notifs = await ( + store.get_unread_event_push_actions_by_room_for_user( + room_id, user_id, last_unread_event_id + ) + ) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 3f5bf75e59..a5c24fbd63 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,7 +426,6 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 683afde52b..10de446065 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4d8a24ce4b..1a68bf32cb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -53,47 +53,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - def encode_json(json_object): """ @@ -239,10 +198,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -864,9 +819,8 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), - "count_as_unread": should_count_as_unread(event, context), } - for event, context in events_and_contexts + for event, _ in events_and_contexts ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7b7393f6e..755b7a2a85 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db_pool.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? - AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c73..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. -ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 51941f99f9..8933b560d2 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -165,26 +165,6 @@ class RestHelper(object): return channel.json_body - def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") - ) - render(request, self.resource, self.hs.get_reactor()) - - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) - ) - - return channel.json_body - def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import read_marker, sync +from synapse.rest.client.v2_alpha import sync from tests import unittest from tests.server import TimedOutException @@ -324,156 +324,3 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: True): - """Syncs and compares the unread count with the expected value.""" - - request, channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, - ) - self.render(request) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1 From 6b7ce1d332766bc2da9e99b22a452e0813a2aae3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 09:25:40 -0400 Subject: Remove some unused database functions. (#8085) --- changelog.d/8085.misc | 1 + synapse/storage/databases/main/event_federation.py | 13 -- synapse/storage/databases/main/events_worker.py | 170 +-------------------- synapse/storage/databases/main/presence.py | 21 --- synapse/storage/databases/main/registration.py | 37 ----- synapse/storage/databases/main/room.py | 4 - .../delta/58/13remove_presence_allow_inbound.sql | 17 +++ 7 files changed, 19 insertions(+), 244 deletions(-) create mode 100644 changelog.d/8085.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8085.misc b/changelog.d/8085.misc new file mode 100644 index 0000000000..c3da1e297c --- /dev/null +++ b/changelog.d/8085.misc @@ -0,0 +1 @@ +Remove some unused database functions. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 484875f989..431bd76693 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_in_room(self, room_id): - return self.db_pool.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - def get_oldest_events_with_depth_in_room(self, room_id): return self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", @@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db_pool.simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 755b7a2a85..5687448e3d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. - - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts - - return self.db_pool.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) - @defer.inlineCallbacks def get_event( self, @@ -923,36 +887,6 @@ class EventsWorkerStore(SQLBaseStore): ) return results - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db_pool.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. @@ -1222,97 +1156,6 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ @@ -1357,14 +1200,3 @@ class EventsWorkerStore(SQLBaseStore): return self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index fd213d2dfd..9f691e5792 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -157,24 +157,3 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 402ae25571..7965a52e30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1345,43 +1345,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "validate_threepid_session_txn", validate_threepid_session_txn ) - def upsert_threepid_validation_session( - self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self.db_pool.simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - ) - def start_or_continue_validation_session( self, medium, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index f4008e6221..aef08c7e12 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -35,10 +35,6 @@ from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - RatelimitOverride = collections.namedtuple( "RatelimitOverride", ("messages_per_second", "burst_count") ) diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql new file mode 100644 index 0000000000..15421b99ac --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This table is no longer used. +DROP TABLE IF EXISTS presence_allow_inbound; -- cgit 1.5.1 From 050e20e7ca56c3a5985fdcf64012800c153260f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 12:18:01 -0400 Subject: Convert some of the general database methods to async (#8100) --- changelog.d/8100.misc | 1 + synapse/storage/database.py | 23 ++++++++----------- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/events_worker.py | 16 +++++++------ synapse/storage/databases/main/registration.py | 8 +++---- synapse/storage/databases/main/roommember.py | 4 ++-- tests/handlers/test_profile.py | 4 ++-- tests/handlers/test_typing.py | 2 +- tests/storage/test_appservice.py | 16 +++++++++---- tests/storage/test_base.py | 16 ++++++++----- tests/storage/test_event_push_actions.py | 30 +++++++++++++------------ tests/storage/test_main.py | 2 +- tests/storage/test_profile.py | 4 ++-- 13 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8100.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8100.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4ada6f5563..8a9e06efcf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -332,8 +332,7 @@ class DatabasePool(object): """ return self._db_pool.running - @defer.inlineCallbacks - def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self): """ Is it safe to use native UPSERT? @@ -342,7 +341,7 @@ class DatabasePool(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self.simple_select_list( + updates = await self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -614,8 +613,7 @@ class DatabasePool(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -631,7 +629,7 @@ class DatabasePool(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) + await self.runInteraction(desc, self.simple_insert_txn, table, values) except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -684,8 +682,7 @@ class DatabasePool(object): txn.executemany(sql, vals) - @defer.inlineCallbacks - def simple_upsert( + async def simple_upsert( self, table, keyvalues, @@ -714,14 +711,14 @@ class DatabasePool(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(None or bool): Native upserts always return None. Emulated + None or bool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ attempts = 0 while True: try: - result = yield self.runInteraction( + return await self.runInteraction( desc, self.simple_upsert_txn, table, @@ -730,7 +727,6 @@ class DatabasePool(object): insertion_values, lock=lock, ) - return result except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -1121,8 +1117,7 @@ class DatabasePool(object): return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def simple_select_many_batch( + async def simple_select_many_batch( self, table, column, @@ -1156,7 +1151,7 @@ class DatabasePool(object): it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: - rows = yield self.runInteraction( + rows = await self.runInteraction( desc, self.simple_select_many_txn, table, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5cf1a88399..02568a2391 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore( service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. Returns: - A Deferred which resolves when the state was set successfully. + An Awaitable which resolves when the state was set successfully. """ return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5687448e3d..8c63a0dc4d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", + rows = yield defer.ensureDeferred( + self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) ) return {r["event_id"] for r in rows} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de50fa6e94..068ad22b30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,7 @@ import logging import re -from typing import Dict, List, Optional - -from twisted.internet.defer import Deferred +from typing import Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore): id_server (str) Returns: - Deferred + Awaitable """ # We need to use an upsert, in case they user had already bound the # threepid @@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> Awaitable: """Record a mapping from an external user id to a mxid Args: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1cc8c08ed0..161edbeccb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d70e1fc608..b609b30d4a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() self.hs = hs @@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") + yield defer.ensureDeferred(self.store.create_profile("caroline")) yield self.store.set_profile_displayname("caroline", "Caroline") response = yield defer.ensureDeferred( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64afd581bc..e01de158e5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 98b74890d5..a425e66f37 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index efcaeef1e7..13bcac743a 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", values={"columname": "Value"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", values={"columname": "Value"} + ) ) self.mock_txn.execute.assert_called_with( @@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 857db071d4..238bad5b45 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db_pool.simple_insert( - "events", - { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }, + return defer.ensureDeferred( + self.store.db_pool.simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) ) # start with the base case where there are no events in the table diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index ab0df5ea93..fbf8af940a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") - yield self.store.create_profile(self.user.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield self.store.set_profile_displayname(self.user.localpart, self.displayname) users, total = yield self.store.get_users_paginate( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9b6f7211ae..9d5b8aa47d 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") @@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" -- cgit 1.5.1 From f40645e60b9cab69c953094848be61c0989a91cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 16:20:49 -0400 Subject: Convert events worker database to async/await. (#8071) --- changelog.d/8071.misc | 1 + synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 6 +- synapse/handlers/room_member.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 30 +++-- synapse/storage/databases/main/events_worker.py | 132 ++++++++++++--------- synapse/storage/databases/main/stream.py | 1 - .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- 12 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 changelog.d/8071.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..5b270228e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 532fc30681..b999d91d1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -960,7 +960,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1080,8 +1080,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..aa1ccde211 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -716,7 +716,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..e3a154a527 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From eebf52be060876ff14bbcbbc86b64ff9965b3622 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 19 Aug 2020 07:26:03 -0400 Subject: Be stricter about JSON that is accepted by Synapse (#8106) --- changelog.d/8106.bugfix | 1 + synapse/api/errors.py | 6 +++--- synapse/federation/federation_server.py | 5 ++--- synapse/federation/sender/transaction_manager.py | 5 ++--- synapse/handlers/e2e_keys.py | 8 ++++---- synapse/handlers/identity.py | 5 ++--- synapse/handlers/message.py | 5 +++-- synapse/handlers/oidc_handler.py | 6 +++--- synapse/handlers/ui_auth/checkers.py | 5 ++--- synapse/http/client.py | 11 ++++++----- synapse/http/federation/well_known_resolver.py | 5 ++--- synapse/http/servlet.py | 5 ++--- synapse/logging/opentracing.py | 7 +++++-- synapse/replication/tcp/commands.py | 12 +++++------- synapse/rest/client/v1/room.py | 11 +++++++---- synapse/rest/client/v2_alpha/sync.py | 5 ++--- synapse/rest/key/v2/remote_key_resource.py | 8 +++++--- synapse/storage/_base.py | 7 +++---- synapse/storage/databases/main/events_worker.py | 16 ++++++++++++++-- synapse/util/__init__.py | 14 ++++++++++++-- 20 files changed, 85 insertions(+), 62 deletions(-) create mode 100644 changelog.d/8106.bugfix (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8106.bugfix b/changelog.d/8106.bugfix new file mode 100644 index 0000000000..c46c60448f --- /dev/null +++ b/changelog.d/8106.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invalid JSON would be accepted by Synapse. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6e40630ab6..a3f314118a 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -21,10 +21,10 @@ import typing from http import HTTPStatus from typing import Dict, List, Optional, Union -from canonicaljson import json - from twisted.web import http +from synapse.util import json_decoder + if typing.TYPE_CHECKING: from synapse.types import JsonDict @@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException): # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text try: - j = json.loads(self.response.decode("utf-8")) + j = json_decoder.decode(self.response.decode("utf-8")) except ValueError: j = {} diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 11c5d63298..630f571cd4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,7 +28,6 @@ from typing import ( Union, ) -from canonicaljson import json from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -63,7 +62,7 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, unwrapFirstError +from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -551,7 +550,7 @@ class FederationServer(FederationBase): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_str) + key_id: json_decoder.decode(json_str) } logger.info( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index c7f6cb3d73..9bd534a313 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -15,8 +15,6 @@ import logging from typing import TYPE_CHECKING, List, Tuple -from canonicaljson import json - from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -28,6 +26,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.util import json_decoder from synapse.util.metrics import measure_func if TYPE_CHECKING: @@ -71,7 +70,7 @@ class TransactionManager(object): for edu in pending_edus: context = edu.get_context() if context: - span_contexts.append(extract_text_map(json.loads(context))) + span_contexts.append(extract_text_map(json_decoder.decode(context))) if keep_destination: edu.strip_context() diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 84169c1022..d8def45e38 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -19,7 +19,7 @@ import logging from typing import Dict, List, Optional, Tuple import attr -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from signedjson.key import VerifyKey, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 @@ -35,7 +35,7 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -404,7 +404,7 @@ class E2eKeysHandler(object): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json_decoder.decode(json_bytes) } @trace @@ -1186,7 +1186,7 @@ def _exception_to_failure(e): def _one_time_keys_match(old_key_json, new_key): - old_key = json.loads(old_key_json) + old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly if not isinstance(old_key, dict) or not isinstance(new_key, dict): diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 92b7404706..0ce6ddfbe4 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,8 +21,6 @@ import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple -from canonicaljson import json - from twisted.internet.error import TimeoutError from synapse.api.errors import ( @@ -34,6 +32,7 @@ from synapse.api.errors import ( from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, Requester +from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: - data = json.loads(e.msg) # XXX WAT? + data = json_decoder.decode(e.msg) # XXX WAT? return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b999d91d1a..c955a86be0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet.interfaces import IDelayedCall @@ -55,6 +55,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -864,7 +865,7 @@ class EventCreationHandler(object): # Ensure that we can round trip before trying to persist in db try: dump = frozendict_json_encoder.encode(event.content) - json.loads(dump) + json_decoder.decode(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) raise diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 87d28a7ae9..dd3703cbd2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode @@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -367,7 +367,7 @@ class OidcHandler: # and check for an error field. If not, we respond with a generic # error message. try: - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) error = resp["error"] description = resp.get("error_description", error) except (ValueError, KeyError): @@ -384,7 +384,7 @@ class OidcHandler: # Since it is a not a 5xx code, body should be a valid JSON. It will # raise if not. - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) if "error" in resp: error = resp["error"] diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index a011e9fe29..9146dc1a3b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -16,13 +16,12 @@ import logging from typing import Any -from canonicaljson import json - from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError from synapse.config.emailconfig import ThreepidBehaviour +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = json.loads(data.decode("utf-8")) + resp_body = json_decoder.decode(data.decode("utf-8")) if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly diff --git a/synapse/http/client.py b/synapse/http/client.py index 8aeb70cdec..dad01a8e56 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -19,7 +19,7 @@ import urllib from io import BytesIO import treq -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -47,6 +47,7 @@ from synapse.http import ( from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred logger = logging.getLogger(__name__) @@ -391,7 +392,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body @@ -433,7 +434,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body @@ -463,7 +464,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) body = await self.get_raw(uri, args, headers=headers) - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) async def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. @@ -506,7 +507,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 89a3b041ce..f794315deb 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import random import time @@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers from synapse.logging.context import make_deferred_yieldable -from synapse.util import Clock +from synapse.util import Clock, json_decoder from synapse.util.caches.ttlcache import TTLCache from synapse.util.metrics import Measure @@ -181,7 +180,7 @@ class WellKnownResolver(object): if response.code != 200: raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode("utf-8")) + parsed_body = json_decoder.decode(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) result = parsed_body["m.server"].encode("ascii") diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index a34e5ead88..53acba56cb 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -17,9 +17,8 @@ import logging -from canonicaljson import json - from synapse.api.errors import Codes, SynapseError +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): return None try: - content = json.loads(content_bytes.decode("utf-8")) + content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 21dbd9f415..abe532d350 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -177,6 +177,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.config import ConfigError +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.http.site import SynapseRequest @@ -499,7 +500,9 @@ def start_active_span_from_edu( if opentracing is None: return _noop_context_manager() - carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {}) + carrier = json_decoder.decode(edu_content.get("context", "{}")).get( + "opentracing", {} + ) context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) _references = [ opentracing.child_of(span_context_from_string(x)) @@ -699,7 +702,7 @@ def span_context_from_string(carrier): Returns: The active span context decoded from a string. """ - carrier = json.loads(carrier) + carrier = json_decoder.decode(carrier) return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index d853e4447e..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -21,9 +21,7 @@ import abc import logging from typing import Tuple, Type -from canonicaljson import json - -from synapse.util import json_encoder as _json_encoder +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -134,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -359,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -367,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2ab30ce897..f216382636 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ import re from typing import List, Optional from urllib import parse as urlparse -from canonicaljson import json - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder MYPY = False if MYPY: @@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..96488b131a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -16,8 +16,6 @@ import itertools import logging -from canonicaljson import json - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken +from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet): filter_collection = DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: - filter_object = json.loads(filter_id) + filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( filter_object, self.hs.config.filter_timeline_limit ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e266204f95..5db7f81c2d 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,19 +15,19 @@ import logging from typing import Dict, Set -from canonicaljson import json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.util import json_decoder logger = logging.getLogger(__name__) class RemoteKey(DirectServeJsonResource): - """HTTP resource for retreiving the TLS certificate and NACL signature + """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks that the NACL signature for the remote server is valid. Returns a dict of @@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource): # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) + # If there is a cache miss, request the missing keys, then recurse (and + # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: await self.fetcher.get_keys(cache_misses) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json.decode("utf-8")) + key_json = json_decoder.decode(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6814bf5fcf..ab49d227de 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,12 +19,11 @@ import random from abc import ABCMeta from typing import Any, Optional -from canonicaljson import json - from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.types import Collection, get_domain_from_id +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -99,13 +98,13 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, since + # Decode it to a Unicode string before feeding it to the JSON decoder, since # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") try: - return json.loads(db_content) + return json_decoder.decode(db_content) except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e3a154a527..4a3333c0db 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -596,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = db_to_json(row["json"]) - internal_metadata = db_to_json(row["internal_metadata"]) + # If the event or metadata cannot be parsed, log the error and act + # as if the event is unknown. + try: + d = db_to_json(row["json"]) + except ValueError: + logger.error("Unable to parse json from event: %s", event_id) + continue + try: + internal_metadata = db_to_json(row["internal_metadata"]) + except ValueError: + logger.error( + "Unable to parse internal_metadata from event: %s", event_id + ) + continue format_version = row["format_version"] if format_version is None: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b3f76428b6..b2a22dbd5c 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -25,8 +25,18 @@ from synapse.logging import context logger = logging.getLogger(__name__) -# Create a custom encoder to reduce the whitespace produced by JSON encoding. -json_encoder = json.JSONEncoder(separators=(",", ":")) + +def _reject_invalid_json(val): + """Do not allow Infinity, -Infinity, or NaN values in JSON.""" + raise json.JSONDecodeError("Invalid JSON value: '%s'" % val) + + +# Create a custom encoder to reduce the whitespace produced by JSON encoding and +# ensure that valid JSON is produced. +json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) + +# Create a custom decoder to reject Python extensions to JSON. +json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) def unwrapFirstError(failure): -- cgit 1.5.1 From 318f4e738e66a1376a8db25c217a090384083f2d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 20 Aug 2020 16:42:12 +0100 Subject: Be more tolerant of membership events in unknown rooms (#8110) It turns out that not all out-of-band membership events are labelled as such, so we need to be more accepting here. --- changelog.d/8110.bugfix | 1 + synapse/events/__init__.py | 2 ++ synapse/storage/databases/main/events_worker.py | 31 ++++++++++++++++++++----- 3 files changed, 28 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8110.bugfix (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8110.bugfix b/changelog.d/8110.bugfix new file mode 100644 index 0000000000..5269a232e1 --- /dev/null +++ b/changelog.d/8110.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index cc5deca75b..67db763dbf 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -133,6 +133,8 @@ class _EventInternalMetadata(object): rejection. This is needed as those events are marked as outliers, but they still need to be processed as if they're new events (e.g. updating invite state in the database, relaying to clients, etc). + + (Added in synapse 0.99.0, so may be unreliable for events received before that) """ return self._dict.get("out_of_band_membership", False) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4a3333c0db..e1241a724b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -620,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row["room_version_id"] if not room_version_id: - # this should only happen for out-of-band membership events - if not internal_metadata.get("out_of_band_membership"): - logger.warning( - "Room %s for event %s is unknown", d["room_id"], event_id + # this should only happen for out-of-band membership events which + # arrived before #6983 landed. For all other events, we should have + # an entry in the 'rooms' table. + # + # However, the 'out_of_band_membership' flag is unreliable for older + # invites, so just accept it for all membership events. + # + if d["type"] != EventTypes.Member: + raise Exception( + "Room %s for event %s is unknown" % (d["room_id"], event_id) ) - continue - # take a wild stab at the room version based on the event format + # so, assuming this is an out-of-band-invite that arrived before #6983 + # landed, we know that the room version must be v5 or earlier (because + # v6 hadn't been invented at that point, so invites from such rooms + # would have been rejected.) + # + # The main reason we need to know the room version here (other than + # choosing the right python Event class) is in case the event later has + # to be redacted - and all the room versions up to v5 used the same + # redaction algorithm. + # + # So, the following approximations should be adequate. + if format_version == EventFormatVersions.V1: + # if it's event format v1 then it must be room v1 or v2 room_version = RoomVersions.V1 elif format_version == EventFormatVersions.V2: + # if it's event format v2 then it must be room v3 room_version = RoomVersions.V3 else: + # if it's event format v3 then it must be room v4 or v5 room_version = RoomVersions.V5 else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) -- cgit 1.5.1 From 4c6c56dc58aba7af92f531655c2355d8f25e529c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 07:19:32 -0400 Subject: Convert simple_select_one and simple_select_one_onecol to async (#8162) --- changelog.d/8162.misc | 1 + synapse/storage/database.py | 36 +++++++++++--- synapse/storage/databases/main/devices.py | 14 +++--- synapse/storage/databases/main/directory.py | 4 +- synapse/storage/databases/main/e2e_room_keys.py | 8 ++-- synapse/storage/databases/main/events_worker.py | 10 ++-- synapse/storage/databases/main/group_server.py | 18 ++++--- synapse/storage/databases/main/media_repository.py | 13 +++-- .../storage/databases/main/monthly_active_users.py | 15 +++--- synapse/storage/databases/main/profile.py | 17 ++++--- synapse/storage/databases/main/receipts.py | 6 ++- synapse/storage/databases/main/registration.py | 10 ++-- synapse/storage/databases/main/rejections.py | 5 +- synapse/storage/databases/main/room.py | 10 ++-- synapse/storage/databases/main/state.py | 4 +- synapse/storage/databases/main/stats.py | 10 ++-- synapse/storage/databases/main/user_directory.py | 9 ++-- tests/handlers/test_profile.py | 56 +++++++++++++++++----- tests/handlers/test_typing.py | 4 +- tests/module_api/test_api.py | 2 +- tests/storage/test_base.py | 28 ++++++----- tests/storage/test_devices.py | 8 ++-- tests/storage/test_profile.py | 23 +++++++-- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 20 ++++++-- 25 files changed, 220 insertions(+), 113 deletions(-) create mode 100644 changelog.d/8162.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8162.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bc327e344e..181c3ec249 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,9 +29,11 @@ from typing import ( Tuple, TypeVar, Union, + overload, ) from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi from twisted.internet import defer @@ -1020,14 +1022,36 @@ class DatabasePool(object): return txn.execute_batch(sql, args) - def simple_select_one( + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( self, table: str, keyvalues: Dict[str, Any], retcols: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> defer.Deferred: + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1038,18 +1062,18 @@ class DatabasePool(object): allow_none: If true, return None instead of failing if the SELECT statement returns no rows """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + async def simple_select_one_onecol( self, table: str, keyvalues: Dict[str, Any], retcol: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> defer.Deferred: + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1061,7 +1085,7 @@ class DatabasePool(object): statement returns no rows desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 03b45dbc4d..a811a39eb5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id: str, device_id: str): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id: str): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 037e02603c..301d5d845a 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db_pool.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 2eeb9f97dc..46c3e33cc6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e1241a724b..d59d73938a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a488e0924b..c39864f59f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = "" class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db_pool.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore): ) return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 80fc1cd009..4ae255ebd8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db_pool.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e71cdd2cb4..fe30552c08 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + Arguments: + user_id: user to add/update + + Return: + Timestamp since last seen, None if never seen """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8261357d4..b8233c4848 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - async def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", @@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - def get_profile_displayname(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", desc="get_profile_displayname", ) - def get_profile_avatar_url(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_from_remote_profile_cache(self, user_id): - return self.db_pool.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6821476ee0..cea5ac9a68 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db_pool.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 321a51cc6a..eced53d470 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index cf9ba51205..1e361aaa9a 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index b3772be2b2..97ecdb16e4 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore): return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db_pool.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 991233a9bc..458f169617 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 802c9019b9..9fe97af56a 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore): return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af21fe457a..20cbcd851c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,6 +15,7 @@ import logging import re +from typing import Any, Dict, Optional from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - def get_user_in_directory(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - def get_user_directory_stream_pos(self): - return self.db_pool.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index b609b30d4a..60ebc95f3e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) displayname = yield defer.ensureDeferred( self.handler.get_displayname(self.frank) @@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) @defer.inlineCallbacks @@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) # Setting displayname a second time is forbidden @@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield self.store.set_profile_displayname("caroline", "Caroline") + yield defer.ensureDeferred( + self.store.set_profile_displayname("caroline", "Caroline") + ) response = yield defer.ensureDeferred( self.query_handlers["profile"]( @@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/pic.gif", ) @@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) @@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index e01de158e5..834b4a0af6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.return_value = ( + self.datastore.get_user_directory_stream_pos.side_effect = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam - defer.succeed(1) + lambda: make_awaitable(1) ) self.datastore.get_current_state_deltas.return_value = (0, None) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 807cd65dd6..04de0b9dbe 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the new user exists with all provided attributes self.assertEqual(user_id, "@bob:test") self.assertTrue(access_token) - self.assertTrue(self.store.get_user_by_id(user_id)) + self.assertTrue(self.get_success(self.store.get_user_by_id(user_id))) # Check that the email was assigned emails = self.get_success(self.store.user_get_threepids(user_id)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 13bcac743a..bf22540d99 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db_pool.simple_select_one_onecol( - table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + value = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one_onecol( + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + ) ) self.assertEquals("Value", value) @@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"], + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"], + ) ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) @@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True, + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True, + ) ) self.assertFalse(ret) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 87ed8f8cd1..34ae8c9da7 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # check it worked - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) @defer.inlineCallbacks diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9d5b8aa47d..3fd0a38cf5 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase): def test_displayname(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ), ) @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.u_frank.localpart, "http://my.site/here" + ) ) self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ), ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 58f827d8d3..70c55cd650 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "user_type": None, "deactivated": 0, }, - (yield self.store.get_user_by_id(self.user_id)), + (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index d07b985a8e..bc8400f240 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield self.store.get_room(self.room.to_string())), + (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), ) @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone((yield self.store.get_room("!uknown:test")),) + self.assertIsNone( + (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) + ) @defer.inlineCallbacks def test_get_room_with_stats(self): @@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - (yield self.store.get_room_with_stats(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): - self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats("!uknown:test") + ) + ), + ) class RoomEventsStoreTestCase(unittest.TestCase): -- cgit 1.5.1 From e3c91a3c5593c52298a7c511737d3e0eec4135ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Aug 2020 13:15:20 +0100 Subject: Make SlavedIdTracker.advance have same interface as MultiWriterIDGenerator (#8171) --- changelog.d/8171.misc | 1 + synapse/replication/slave/storage/_slaved_id_tracker.py | 4 ++-- synapse/replication/slave/storage/account_data.py | 4 ++-- synapse/replication/slave/storage/deviceinbox.py | 2 +- synapse/replication/slave/storage/devices.py | 4 ++-- synapse/replication/slave/storage/groups.py | 2 +- synapse/replication/slave/storage/presence.py | 2 +- synapse/replication/slave/storage/push_rule.py | 2 +- synapse/replication/slave/storage/pushers.py | 2 +- synapse/replication/slave/storage/receipts.py | 2 +- synapse/replication/slave/storage/room.py | 2 +- synapse/storage/databases/main/events_worker.py | 4 ++-- 12 files changed, 16 insertions(+), 15 deletions(-) create mode 100644 changelog.d/8171.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8171.misc b/changelog.d/8171.misc new file mode 100644 index 0000000000..cafbf23d83 --- /dev/null +++ b/changelog.d/8171.misc @@ -0,0 +1 @@ +Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`. diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index d43eaf3a29..047f2c50f7 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -21,9 +21,9 @@ class SlavedIdTracker(object): self.step = step self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self.advance(_load_current_id(db_conn, table, column)) + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, new_id): + def advance(self, instance_name, new_id): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 154f0e687c..bb66ba9b80 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(token) + self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index ee7f69a918..533d927701 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ToDeviceStream.NAME: - self._device_inbox_id_gen.advance(token) + self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 722f3745e9..596c72eb92 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 3291558c7a..567b4a5cc1 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == GroupServerStream.NAME: - self._group_updates_id_gen.advance(token) + self._group_updates_id_gen.advance(instance_name, token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index a912c04360..025f6f6be8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(token) + self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 90d90833f9..de904c943c 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(token) + self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 63300e5da6..9da218bfe8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(token) + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 17ba1f22ac..5c2986e050 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(token) + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 427c81772b..80ae803ad9 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(token) + self._public_room_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d59d73938a..e6247d682d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -113,9 +113,9 @@ class EventsWorkerStore(SQLBaseStore): def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(token) + self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(-token) + self._backfill_id_gen.advance(instance_name, -token) super().process_replication_rows(stream_name, instance_name, token, rows) -- cgit 1.5.1 From 54f8d73c005cf0401d05fc90e857da253f9d1168 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:21:48 -0400 Subject: Convert additional databases to async/await (#8199) --- changelog.d/8199.misc | 1 + synapse/storage/databases/main/__init__.py | 50 +++++++----- synapse/storage/databases/main/devices.py | 38 +++++---- synapse/storage/databases/main/events_worker.py | 48 ++++++----- synapse/storage/databases/main/purge_events.py | 30 +++---- synapse/storage/databases/main/receipts.py | 14 ++-- synapse/storage/databases/main/relations.py | 103 +++++++++++------------- 7 files changed, 147 insertions(+), 137 deletions(-) create mode 100644 changelog.d/8199.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8199.misc b/changelog.d/8199.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8199.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index e6536c8456..99890ffbf3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ import calendar import logging import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -294,16 +294,16 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - def count_daily_users(self): + async def count_daily_users(self) -> int: """ Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_users", self._count_users, yesterday ) - def count_monthly_users(self): + async def count_monthly_users(self) -> int: """ Counts the number of users who used this homeserver in the last 30 days. Note this method is intended for phonehome metrics only and is different @@ -311,7 +311,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -330,15 +330,15 @@ class DataStore( (count,) = txn.fetchone() return count - def count_r30_users(self): + async def count_r30_users(self) -> Dict[str, int]: """ Counts the number of 30 day retained users, defined as:- * Users who have created their accounts more than 30 days ago * Where last seen at most 30 days ago * Where account creation and last_seen are > 30 days apart - Returns counts globaly for a given user as well as breaking - by platform + Returns: + A mapping of counts globally as well as broken out by platform. """ def _count_r30_users(txn): @@ -411,7 +411,7 @@ class DataStore( return results - return self.db_pool.runInteraction("count_r30_users", _count_r30_users) + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -421,7 +421,7 @@ class DataStore( today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 - def generate_user_daily_visits(self): + async def generate_user_daily_visits(self) -> None: """ Generates daily visit data for use in cohort/ retention analysis """ @@ -476,7 +476,7 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -500,22 +500,28 @@ class DataStore( desc="get_users", ) - def get_users_paginate( - self, start, limit, user_id=None, name=None, guests=True, deactivated=False - ): + async def get_users_paginate( + self, + start: int, + limit: int, + user_id: Optional[str] = None, + name: Optional[str] = None, + guests: bool = True, + deactivated: bool = False, + ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: - start (int): start number to begin the query from - limit (int): number of rows to retrieve - user_id (string): search for user_id. ignored if name is not None - name (string): search for local part of user_id or display name - guests (bool): whether to in include guest users - deactivated (bool): whether to include deactivated users + start: start number to begin the query from + limit: number of rows to retrieve + user_id: search for user_id. ignored if name is not None + name: search for local part of user_id or display name + guests: whether to in include guest users + deactivated: whether to include deactivated users Returns: - defer.Deferred: resolves to list[dict[str, Any]], int + A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn(txn): @@ -558,7 +564,7 @@ class DataStore( users = self.db_pool.cursor_to_dict(txn) return users, count - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_paginate_txn", get_users_paginate_txn ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e8379c73c4..a29157d979 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -313,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore): return results - def _get_last_device_update_for_remote_user( + async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int - ): + ) -> int: def f(txn): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id @@ -326,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) + return await self.db_pool.runInteraction( + "get_last_device_update_for_remote_user", f + ) - def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): + async def mark_as_sent_devices_by_remote( + self, destination: str, stream_id: int + ) -> None: """Mark that updates have successfully been sent to the destination. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -684,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): + async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user. """ @@ -698,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", _mark_remote_user_device_list_as_unsubscribed_txn, ) @@ -959,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - def update_remote_device_list_cache_entry( + async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: int - ): + ) -> None: """Updates a single device in the cache of a remote user's devicelist. Note: assumes that we are the only thread that can be updating this user's @@ -972,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: ID of decivice being updated content: new data on this device stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -1028,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache( + async def update_remote_device_list_cache( self, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's @@ -1040,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: User to update device list for devices: list of device objects supplied over federation stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -1054,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e6247d682d..a7a73cc3d8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - def _maybe_redact_event_row(self, original_ev, redactions, event_map): + def _maybe_redact_event_row( + self, + original_ev: EventBase, + redactions: Iterable[str], + event_map: Dict[str, EventBase], + ) -> Optional[EventBase]: """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. + original_ev: The original event. + redactions: list of event ids of potential redaction events + event_map: other events which have been fetched, in which we can + look up the redaaction events. Map from event id to event. Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. + If the event should be redacted, a pruned event object. Otherwise, None. """ if original_ev.type == "m.room.create": # we choose to ignore redactions of m.room.create events. @@ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore): row = txn.fetchone() return row[0] if row else 0 - def get_current_state_event_counts(self, room_id): + async def get_current_state_event_counts(self, room_id: str) -> int: """ Gets the current number of state events in a room. Args: - room_id (str) + room_id: The room ID to query. Returns: - Deferred[int] + The current number of state events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, @@ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() - def get_all_new_forward_event_rows(self, last_id, current_id, limit): + async def get_all_new_forward_event_rows( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore): current_id: the maximum stream_id to return up to limit: the maximum number of rows to return - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) - def get_ex_outlier_stream_rows(self, last_id, current_id): + async def get_ex_outlier_stream_rows( + self, last_id: int, current_id: int + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) @@ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_next_event_to_expire(self): + async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. - Returns: Deferred[Optional[Tuple[str, int]]] + Returns: A tuple containing the event ID as its first element and an expiry timestamp as its second one, if there's at least one row in the event_expiry table. None otherwise. @@ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore): return txn.fetchone() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 3526b6fd66..ea833829ae 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Tuple +from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -25,25 +25,24 @@ logger = logging.getLogger(__name__) class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> Set[int]: """Deletes room history before a certain point Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): + room_id: + token: A topological token to delete events before + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + The set of state groups that are referenced by deleted events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): return referenced_state_groups - def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> List[int]: """Deletes all record of a room Args: - room_id (str) + room_id Returns: - Deferred[List[int]]: The list of state groups to delete. + The list of state groups to delete. """ - - return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) + return await self.db_pool.runInteraction( + "purge_room", self._purge_room_txn, room_id + ) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 436f22ad2d..4a0d5a320e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - def get_users_sent_receipts_between(self, last_id: int, current_id: int): + async def get_users_sent_receipts_between( + self, last_id: int, current_id: int + ) -> List[str]: """Get all users who sent receipts between `last_id` exclusive and `current_id` inclusive. Returns: - Deferred[List[str]] + The list of users. """ if last_id == current_id: @@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return [r[0] for r in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn ) @@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db_pool.runInteraction( + async def insert_graph_receipt( + self, room_id, receipt_type, user_id, event_ids, data + ): + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index a9ceffc20e..5cd61547f7 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -34,38 +34,33 @@ logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) - def get_relations_for_event( + async def get_relations_for_event( self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[RelationPaginationToken] = None, + to_token: Optional[RelationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. + event_id: Fetch events that relate to this event ID. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. + limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. + List of event IDs that match relations requested. The rows are of + the form `{"event_id": "..."}`. """ where_clause = ["relates_to_id = ?"] @@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @cached(tree=True) - def get_aggregation_groups_for_event( + async def get_aggregation_groups_for_event( self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + event_type: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[AggregationPaginationToken] = None, + to_token: Optional[AggregationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore): on an event. Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or + event_id: Fetch events that relate to this event ID. + event_type: Only fetch events with this event type, if given. + limit: Only fetch the `limit` groups. + direction: Whether to fetch the highest count first (`"b"`) or the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. + List of groups of annotations that match. Each row is a dict with + `type`, `key` and `count` fields. """ where_clause = ["relates_to_id = ?", "relation_type = ?"] @@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore): return await self.get_event(edit_id, allow_none=True) - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + async def has_user_annotated_event( + self, parent_id: str, event_type: str, aggregation_key: str, sender: str + ) -> bool: """Check if a user has already annotated an event with the same key (e.g. already liked an event). Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation + parent_id: The event being annotated + event_type: The event type of the annotation + aggregation_key: The aggregation key of the annotation + sender: The sender of the annotation Returns: - Deferred[bool] + True if the event is already annotated. """ sql = """ @@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) -- cgit 1.5.1 From 82c1ee1c22a87b9e6e3179947014b0f11c0a1ac3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Sep 2020 15:48:37 +0100 Subject: Add experimental support for sharding event persister. (#8170) This is *not* ready for production yet. Caveats: 1. We should write some tests... 2. The stream token that we use for events can get stalled at the minimum position of all writers. This means that new events may not be processed and e.g. sent down sync streams if a writer isn't writing or is slow. --- changelog.d/8170.feature | 1 + synapse/config/_base.py | 21 ++++++- synapse/config/_base.pyi | 1 + synapse/config/workers.py | 37 ++++++++---- synapse/handlers/federation.py | 44 ++++++++++----- synapse/handlers/message.py | 14 +++-- synapse/handlers/room.py | 14 +++-- synapse/handlers/room_member.py | 7 --- synapse/replication/http/federation.py | 12 +++- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/events_worker.py | 66 +++++++++++++++------- .../schema/delta/58/14events_instance_name.sql | 16 ++++++ .../delta/58/14events_instance_name.sql.postgres | 26 +++++++++ synapse/storage/util/id_generators.py | 10 ++-- 18 files changed, 206 insertions(+), 77 deletions(-) create mode 100644 changelog.d/8170.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8170.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1417487427..73f0717b0d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 16389a0dca..bd8efbb768 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -923,7 +923,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72bb638167..6398774c02 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -904,9 +903,10 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -974,7 +974,9 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9d5b1828df..55794c3057 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object): # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object): # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a7962b0ada..e5cb3c2d5c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,13 +82,6 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 16c63ff4ec..3705618b4f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 0ac854aee2..c73d54fb67 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -68,7 +68,7 @@ class Databases(object): # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6313b41eef..46b11e705b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -97,6 +97,7 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -108,7 +109,7 @@ class PersistEventsStore: # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -800,6 +801,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644 index 0000000000..98ff76d709 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..97c1e6a0c5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +SELECT setval('events_backfill_stream_seq', ( + SELECT COALESCE(-MIN(stream_ordering), 1) FROM events +)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9f3d23f0a5..8fd21c2bf8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -231,8 +231,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -362,9 +366,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From 9f8abdcc3828e942f09cc62e29dff93dbaa01ec7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 4 Sep 2020 10:19:42 +0100 Subject: Revert "Add experimental support for sharding event persister. (#8170)" (#8242) * Revert "Add experimental support for sharding event persister. (#8170)" This reverts commit 82c1ee1c22a87b9e6e3179947014b0f11c0a1ac3. * Changelog --- changelog.d/8170.feature | 1 - changelog.d/8242.feature | 1 + synapse/config/_base.py | 21 +------ synapse/config/_base.pyi | 1 - synapse/config/workers.py | 37 ++++-------- synapse/handlers/federation.py | 44 +++++---------- synapse/handlers/message.py | 14 ++--- synapse/handlers/room.py | 14 ++--- synapse/handlers/room_member.py | 7 +++ synapse/replication/http/federation.py | 12 +--- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 4 +- synapse/storage/databases/main/events_worker.py | 66 +++++++--------------- .../schema/delta/58/14events_instance_name.sql | 16 ------ .../delta/58/14events_instance_name.sql.postgres | 26 --------- synapse/storage/util/id_generators.py | 10 ++-- 19 files changed, 78 insertions(+), 206 deletions(-) delete mode 100644 changelog.d/8170.feature create mode 100644 changelog.d/8242.feature delete mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql delete mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature deleted file mode 100644 index b363e929ea..0000000000 --- a/changelog.d/8170.feature +++ /dev/null @@ -1 +0,0 @@ -Add experimental support for sharding event persister. diff --git a/changelog.d/8242.feature b/changelog.d/8242.feature new file mode 100644 index 0000000000..f6891e360d --- /dev/null +++ b/changelog.d/8242.feature @@ -0,0 +1 @@ +Back out experimental support for sharding event persister. **PLEASE REMOVE THIS LINE FROM THE FINAL CHANGELOG** diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 73f0717b0d..1417487427 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,26 +832,11 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - # If multiple instances are not defined we always return true + + # If multiple instances are not defined we always return true. if not self.instances or len(self.instances) == 1: return True - return self.get_instance(key) == instance_name - - def get_instance(self, key: str) -> str: - """Get the instance responsible for handling the given key. - - Note: For things like federation sending the config for which instance - is sending is known only to the sender instance if there is only one. - Therefore `should_handle` should be used where possible. - """ - - if not self.instances: - return "master" - - if len(self.instances) == 1: - return self.instances[0] - # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -861,7 +846,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] + return self.instances[remainder] == instance_name __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index b8faafa9bd..eb911e8f9f 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,4 +142,3 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... - def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f23e42cdf9..c784a71508 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,24 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union - import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def -def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: - """Helper for allowing parsing a string or list of strings to a config - option expecting a list of strings. - """ - - if isinstance(obj, str): - return [obj] - return obj - - @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -45,13 +33,11 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instances that write to the event and backfill streams. - typing: The instance that writes to the typing stream. + events: The instance that writes to the event and backfill streams. + events: The instance that writes to the typing stream. """ - events = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter - ) + events = attr.ib(default="master", type=str) typing = attr.ib(default="master", type=str) @@ -119,18 +105,15 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writers for events and typing also appears in + # Check that the configured writer for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instances = _instance_to_list_converter(getattr(self.writers, stream)) - for instance in instances: - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) - - self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + instance = getattr(self.writers, stream) + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 310c7f7138..43f2986f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -923,8 +923,7 @@ class FederationHandler(BaseHandler): ) ) - if ev_infos: - await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1217,7 +1216,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, room_id, event_infos, + destination, event_infos, ) def _sanity_check_event(self, ev): @@ -1364,15 +1363,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, room_id, auth_chain, state, event, room_version_obj + origin, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.events_shard_config.get_instance(room_id), - "events", - max_stream_id, + self.config.worker.writers.events, "events", max_stream_id ) # Check whether this room is the result of an upgrade of a room we already know @@ -1626,7 +1625,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify(event.room_id, [(event, context)]) + await self.persist_events_and_notify([(event, context)]) return event @@ -1653,9 +1652,7 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( - event.room_id, [(event, context)] - ) + stream_id = await self.persist_events_and_notify([(event, context)]) return event, stream_id @@ -1903,7 +1900,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - event.room_id, [(event, context)], backfilled=backfilled + [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1916,7 +1913,6 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, - room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1948,7 +1944,6 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1959,7 +1954,6 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, - room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1974,7 +1968,6 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from - room_id, auth_events state event @@ -2049,20 +2042,17 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( - room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ], + ] ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify( - room_id, [(event, new_event_context)] - ) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2913,7 +2903,6 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, - room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2921,19 +2910,14 @@ class FederationHandler(BaseHandler): necessary. Args: - room_id: The room ID of events being persisted. - event_and_contexts: Sequence of events with their associated - context that should be persisted. All events must belong to - the same room. + event_and_contexts: backfilled: Whether these events are a result of backfilling or not """ - instance = self.config.worker.events_shard_config.get_instance(room_id) - if instance != self._instance_name: + if self.config.worker.writers.events != self._instance_name: result = await self._send_events( - instance_name=instance, + instance_name=self.config.worker.writers.events, store=self.store, - room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6398774c02..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,8 +376,9 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._events_shard_config = self.config.worker.events_shard_config - self._instance_name = hs.get_instance_name() + self._is_event_writer = ( + self.config.worker.writers.events == hs.get_instance_name() + ) self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -903,10 +904,9 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. - writer_instance = self._events_shard_config.get_instance(event.room_id) - if writer_instance != self._instance_name: + if not self._is_event_writer: result = await self.send_event( - instance_name=writer_instance, + instance_name=self.config.worker.writers.events, event_id=event.event_id, store=self.store, requester=requester, @@ -974,9 +974,7 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self._events_shard_config.should_handle( - self._instance_name, event.room_id - ) + assert self._is_event_writer if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 55794c3057..9d5b1828df 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,9 +804,7 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.events_shard_config.get_instance(room_id), - "events", - last_stream_id, + self.hs.config.worker.writers.events, "events", last_stream_id ) return result, last_stream_id @@ -1262,10 +1260,10 @@ class RoomShutdownHandler(object): # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. + # + # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.events_shard_config.get_instance(new_room_id), - "events", - stream_id, + self.hs.config.worker.writers.events, "events", stream_id ) else: new_room_id = None @@ -1295,9 +1293,7 @@ class RoomShutdownHandler(object): # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.events_shard_config.get_instance(room_id), - "events", - stream_id, + self.hs.config.worker.writers.events, "events", stream_id ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e5cb3c2d5c..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,6 +82,13 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._event_stream_writer_instance = hs.config.worker.writers.events + self._is_on_event_persistence_instance = ( + self._event_stream_writer_instance == hs.get_instance_name() + ) + if self._is_on_event_persistence_instance: + self.persist_event_storage = hs.get_storage().persistence + self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5c8be747e1..6b56315148 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,11 +65,10 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, room_id, event_and_contexts, backfilled): + async def _serialize_payload(store, event_and_contexts, backfilled): """ Args: store - room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -89,11 +88,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = { - "events": event_payloads, - "backfilled": backfilled, - "room_id": room_id, - } + payload = {"events": event_payloads, "backfilled": backfilled} return payload @@ -101,7 +96,6 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) - room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -126,7 +120,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - room_id, event_and_contexts, backfilled + event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b323841f73..1c303f3a46 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.get_instance_name() in hs.config.worker.writers.events: + if hs.config.worker.writers.events == hs.get_instance_name(): self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 3705618b4f..16c63ff4ec 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store._stream_id_gen.get_current_token_for_writer, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index c73d54fb67..0ac854aee2 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -68,7 +68,7 @@ class Databases(object): # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.get_instance_name() in hs.config.worker.writers.events: + if hs.config.worker.writers.events == hs.get_instance_name(): persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4c3c162acf..0b69aa6a94 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) + raise StoreError(400, "stream_ordering too old") sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b94fe7ac17..b3d27a2ee7 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -97,7 +97,6 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() - self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -109,7 +108,7 @@ class PersistEventsStore: # This should only exist on instances that are configured to write assert ( - hs.get_instance_name() in hs.config.worker.writers.events + hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -801,7 +800,6 @@ class PersistEventsStore: table="events", values=[ { - "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 17f5997b89..a7a73cc3d8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,8 +42,7 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -79,54 +78,27 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if isinstance(database.engine, PostgresEngine): - # If we're using Postgres than we can use `MultiWriterIdGenerator` - # regardless of whether this process writes to the streams or not. - self._stream_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", - sequence_name="events_stream_seq", + if hs.config.worker.writers.events == hs.get_instance_name(): + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", ) - self._backfill_id_gen = MultiWriterIdGenerator( - db_conn=db_conn, - db=database, - instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", - sequence_name="events_backfill_stream_seq", - positive=False, + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], ) else: - # We shouldn't be running in worker mode with SQLite, but its useful - # to support it for unit tests. - # - # If this process is the writer than we need to use - # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets - # updated over replication. (Multiple writers are not supported for - # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) - else: - self._stream_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering" - ) - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql deleted file mode 100644 index 98ff76d709..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres deleted file mode 100644 index 97c1e6a0c5..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE SEQUENCE IF NOT EXISTS events_stream_seq; - -SELECT setval('events_stream_seq', ( - SELECT COALESCE(MAX(stream_ordering), 1) FROM events -)); - -CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; - -SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events -)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 8fd21c2bf8..9f3d23f0a5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -231,12 +231,8 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. - # - # We start at 1 here as a) the first generated stream ID will be 2, and - # b) other parts of the code assume that stream IDs are strictly greater - # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 1 + min(self._current_positions.values()) if self._current_positions else 0 ) self._known_persisted_positions = [] # type: List[int] @@ -366,7 +362,9 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - return self.get_persisted_upto_position() + # Currently we don't support this operation, as it's not obvious how to + # condense the stream positions of multiple writers into a single int. + raise NotImplementedError() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From 04cc249b43e8716513f788b2a4eeb8ede24d19df Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 14 Sep 2020 10:16:41 +0100 Subject: Add experimental support for sharding event persister. Again. (#8294) This is *not* ready for production yet. Caveats: 1. We should write some tests... 2. The stream token that we use for events can get stalled at the minimum position of all writers. This means that new events may not be processed and e.g. sent down sync streams if a writer isn't writing or is slow. --- changelog.d/8294.feature | 1 + synapse/config/_base.py | 21 ++++++- synapse/config/_base.pyi | 1 + synapse/config/workers.py | 37 ++++++++---- synapse/handlers/federation.py | 44 ++++++++++----- synapse/handlers/message.py | 14 +++-- synapse/handlers/room.py | 14 +++-- synapse/handlers/room_member.py | 7 --- synapse/replication/http/federation.py | 12 +++- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 12 ++-- synapse/storage/databases/main/events_worker.py | 66 +++++++++++++++------- .../schema/delta/58/14events_instance_name.sql | 16 ++++++ .../delta/58/14events_instance_name.sql.postgres | 26 +++++++++ synapse/storage/util/id_generators.py | 10 ++-- 18 files changed, 211 insertions(+), 80 deletions(-) create mode 100644 changelog.d/8294.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8294.feature b/changelog.d/8294.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8294.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ad5ab6ad62..bb9bf8598d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c195eba830..a5734bebab 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -896,7 +896,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1189,7 +1190,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1336,15 +1337,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1593,7 +1594,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1620,7 +1621,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1868,7 +1871,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1881,6 +1884,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1912,6 +1916,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1922,6 +1927,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1936,6 +1942,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2010,17 +2017,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2871,6 +2881,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2878,14 +2889,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e54e2b322b..a8fe5cf4e2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler: self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -902,9 +901,10 @@ class EventCreationHandler: try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -972,8 +972,10 @@ class EventCreationHandler: This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer assert self.storage.persistence is not None + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 53d85ab97d..eeade6ad3f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1259,10 +1261,10 @@ class RoomShutdownHandler: # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1292,7 +1294,9 @@ class RoomShutdownHandler: # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 100f335b80..01a6e88262 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,13 +82,6 @@ class RoomMemberHandler: self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f929fc3954..ccc7ca30d8 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 985b12df91..aa5d490624 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -75,7 +75,7 @@ class Databases: # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9cd1403b38..9a80f419e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -32,7 +32,7 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter @@ -97,18 +97,21 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator + self._backfill_id_gen = ( + self.store._backfill_id_gen + ) # type: MultiWriterIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -809,6 +812,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644 index 0000000000..98ff76d709 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..97c1e6a0c5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +SELECT setval('events_backfill_stream_seq', ( + SELECT COALESCE(-MIN(stream_ordering), 1) FROM events +)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2a66b3ad4e..1de2b91587 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -240,8 +240,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -398,9 +402,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From 837293c314b47e988fe9532115476a6536cd6406 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 17 Sep 2020 14:37:01 +0200 Subject: Remove obsolete __future__ imports (#8337) --- changelog.d/8337.misc | 1 + contrib/cmdclient/console.py | 2 -- contrib/cmdclient/http.py | 2 -- contrib/graph/graph.py | 2 -- contrib/graph/graph3.py | 2 -- contrib/jitsimeetbridge/jitsimeetbridge.py | 2 -- contrib/scripts/kick_users.py | 8 +------- scripts-dev/definitions.py | 2 -- scripts-dev/dump_macaroon.py | 2 -- scripts-dev/federation_client.py | 2 -- scripts-dev/hash_history.py | 2 -- scripts/move_remote_media_to_new_store.py | 2 -- scripts/register_new_matrix_user | 2 -- synapse/_scripts/register_new_matrix_user.py | 2 -- synapse/app/homeserver.py | 2 -- synapse/config/emailconfig.py | 1 - synapse/config/stats.py | 2 -- synapse/storage/databases/main/events_worker.py | 2 -- synapse/util/patch_inline_callbacks.py | 2 -- 19 files changed, 2 insertions(+), 40 deletions(-) create mode 100644 changelog.d/8337.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8337.misc b/changelog.d/8337.misc new file mode 100644 index 0000000000..4daf272204 --- /dev/null +++ b/changelog.d/8337.misc @@ -0,0 +1 @@ +Remove `__future__` imports related to Python 2 compatibility. \ No newline at end of file diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index dfc1d294dc..ab1e1f1f4c 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -15,8 +15,6 @@ # limitations under the License. """ Starts a synapse client console. """ -from __future__ import print_function - import argparse import cmd import getpass diff --git a/contrib/cmdclient/http.py b/contrib/cmdclient/http.py index cd3260b27d..345120b612 100644 --- a/contrib/cmdclient/http.py +++ b/contrib/cmdclient/http.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import json import urllib from pprint import pformat diff --git a/contrib/graph/graph.py b/contrib/graph/graph.py index de33fac1c7..fdbac087bd 100644 --- a/contrib/graph/graph.py +++ b/contrib/graph/graph.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import argparse import cgi import datetime diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index 91db98e7ef..dd0c19368b 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import argparse import cgi import datetime diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index 69aa74bd34..b3de468687 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -10,8 +10,6 @@ the bridge. Requires: npm install jquery jsdom """ -from __future__ import print_function - import json import subprocess import time diff --git a/contrib/scripts/kick_users.py b/contrib/scripts/kick_users.py index 372dbd9e4f..f8e0c732fb 100755 --- a/contrib/scripts/kick_users.py +++ b/contrib/scripts/kick_users.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -from __future__ import print_function import json import sys @@ -8,11 +7,6 @@ from argparse import ArgumentParser import requests -try: - raw_input -except NameError: # Python 3 - raw_input = input - def _mkurl(template, kws): for key in kws: @@ -58,7 +52,7 @@ def main(hs, room_id, access_token, user_id_prefix, why): print("The following user IDs will be kicked from %s" % room_name) for uid in kick_list: print(uid) - doit = raw_input("Continue? [Y]es\n") + doit = input("Continue? [Y]es\n") if len(doit) > 0 and doit.lower() == "y": print("Kicking members...") # encode them all diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 9eddb6d515..15e6ce6e16 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -1,7 +1,5 @@ #! /usr/bin/python -from __future__ import print_function - import argparse import ast import os diff --git a/scripts-dev/dump_macaroon.py b/scripts-dev/dump_macaroon.py index 22b30fa78e..980b5e709f 100755 --- a/scripts-dev/dump_macaroon.py +++ b/scripts-dev/dump_macaroon.py @@ -1,7 +1,5 @@ #!/usr/bin/env python2 -from __future__ import print_function - import sys import pymacaroons diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index ad12523c4d..848a826f17 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -15,8 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import argparse import base64 import json diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index 89acb52e6a..8d6c3d24db 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import sqlite3 import sys diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index b5b63933ab..ab2e763386 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -32,8 +32,6 @@ To use, pipe the above into:: PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py """ -from __future__ import print_function - import argparse import logging import os diff --git a/scripts/register_new_matrix_user b/scripts/register_new_matrix_user index b450712ab7..8b9d30877d 100755 --- a/scripts/register_new_matrix_user +++ b/scripts/register_new_matrix_user @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - from synapse._scripts.register_new_matrix_user import main if __name__ == "__main__": diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 55cce2db22..da0996edbc 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import argparse import getpass import hashlib diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b08319ca77..dff739e106 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -15,8 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import gc import logging import math diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 72b42bfd62..cceffbfee2 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function # This file can't be called email.py because if it is, we cannot: import email.utils diff --git a/synapse/config/stats.py b/synapse/config/stats.py index 62485189ea..b559bfa411 100644 --- a/synapse/config/stats.py +++ b/synapse/config/stats.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import division - import sys from ._base import Config diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 17f5997b89..cd3739c16c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import division - import itertools import logging import threading diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 54c046b6e1..72574d3af2 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import functools import sys from typing import Any, Callable, List -- cgit 1.5.1 From 8a4a4186ded34bab1ffb4ee1cebcb476890da207 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Sep 2020 09:56:44 -0400 Subject: Simplify super() calls to Python 3 syntax. (#8344) This converts calls like super(Foo, self) -> super(). Generated with: sed -i "" -Ee 's/super\([^\(]+\)/super()/g' **/*.py --- changelog.d/8344.misc | 1 + scripts-dev/definitions.py | 2 +- scripts-dev/federation_client.py | 2 +- synapse/api/errors.py | 50 ++++++++++------------ synapse/api/filtering.py | 2 +- synapse/app/generic_worker.py | 6 +-- synapse/appservice/api.py | 2 +- synapse/config/consent_config.py | 2 +- synapse/config/registration.py | 2 +- synapse/config/server_notices_config.py | 2 +- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_client.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 10 ++--- synapse/groups/groups_server.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/auth.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 4 +- synapse/handlers/directory.py | 2 +- synapse/handlers/events.py | 4 +- synapse/handlers/federation.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/initial_sync.py | 2 +- synapse/handlers/profile.py | 4 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member_worker.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/set_password.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/http/__init__.py | 2 +- synapse/logging/formatter.py | 2 +- synapse/logging/scopecontextmanager.py | 6 +-- synapse/push/__init__.py | 2 +- synapse/replication/http/devices.py | 2 +- synapse/replication/http/federation.py | 8 ++-- synapse/replication/http/login.py | 2 +- synapse/replication/http/membership.py | 6 +-- synapse/replication/http/register.py | 4 +- synapse/replication/http/send_event.py | 2 +- synapse/replication/slave/storage/_base.py | 2 +- synapse/replication/slave/storage/account_data.py | 2 +- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/replication/slave/storage/deviceinbox.py | 2 +- synapse/replication/slave/storage/devices.py | 2 +- synapse/replication/slave/storage/events.py | 2 +- synapse/replication/slave/storage/filtering.py | 2 +- synapse/replication/slave/storage/groups.py | 2 +- synapse/replication/slave/storage/presence.py | 2 +- synapse/replication/slave/storage/pushers.py | 2 +- synapse/replication/slave/storage/receipts.py | 2 +- synapse/replication/slave/storage/room.py | 2 +- synapse/replication/tcp/streams/_base.py | 2 +- synapse/rest/admin/devices.py | 2 +- synapse/rest/client/v1/directory.py | 6 +-- synapse/rest/client/v1/events.py | 4 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 4 +- synapse/rest/client/v1/logout.py | 4 +- synapse/rest/client/v1/presence.py | 2 +- synapse/rest/client/v1/profile.py | 6 +-- synapse/rest/client/v1/push_rule.py | 2 +- synapse/rest/client/v1/pusher.py | 6 +-- synapse/rest/client/v1/room.py | 38 ++++++++-------- synapse/rest/client/v1/voip.py | 2 +- synapse/rest/client/v2_alpha/account.py | 22 +++++----- synapse/rest/client/v2_alpha/account_data.py | 4 +- synapse/rest/client/v2_alpha/account_validity.py | 4 +- synapse/rest/client/v2_alpha/auth.py | 2 +- synapse/rest/client/v2_alpha/capabilities.py | 2 +- synapse/rest/client/v2_alpha/devices.py | 6 +-- synapse/rest/client/v2_alpha/filter.py | 4 +- synapse/rest/client/v2_alpha/groups.py | 48 ++++++++++----------- synapse/rest/client/v2_alpha/keys.py | 12 +++--- synapse/rest/client/v2_alpha/notifications.py | 2 +- synapse/rest/client/v2_alpha/openid.py | 2 +- synapse/rest/client/v2_alpha/password_policy.py | 2 +- synapse/rest/client/v2_alpha/read_marker.py | 2 +- synapse/rest/client/v2_alpha/receipts.py | 2 +- synapse/rest/client/v2_alpha/register.py | 10 ++--- synapse/rest/client/v2_alpha/relations.py | 8 ++-- synapse/rest/client/v2_alpha/report_event.py | 2 +- synapse/rest/client/v2_alpha/room_keys.py | 6 +-- .../client/v2_alpha/room_upgrade_rest_servlet.py | 2 +- synapse/rest/client/v2_alpha/sendtodevice.py | 2 +- synapse/rest/client/v2_alpha/shared_rooms.py | 2 +- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/client/v2_alpha/tags.py | 4 +- synapse/rest/client/v2_alpha/thirdparty.py | 8 ++-- synapse/rest/client/v2_alpha/tokenrefresh.py | 2 +- synapse/rest/client/v2_alpha/user_directory.py | 2 +- synapse/rest/client/versions.py | 2 +- synapse/storage/databases/main/__init__.py | 2 +- synapse/storage/databases/main/account_data.py | 4 +- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/deviceinbox.py | 4 +- synapse/storage/databases/main/devices.py | 4 +- synapse/storage/databases/main/event_federation.py | 2 +- .../storage/databases/main/event_push_actions.py | 4 +- .../storage/databases/main/events_bg_updates.py | 2 +- synapse/storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/media_repository.py | 6 +-- .../storage/databases/main/monthly_active_users.py | 4 +- synapse/storage/databases/main/push_rule.py | 2 +- synapse/storage/databases/main/receipts.py | 4 +- synapse/storage/databases/main/registration.py | 6 +-- synapse/storage/databases/main/room.py | 6 +-- synapse/storage/databases/main/roommember.py | 6 +-- synapse/storage/databases/main/search.py | 4 +- synapse/storage/databases/main/state.py | 6 +-- synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/stream.py | 2 +- synapse/storage/databases/main/transactions.py | 2 +- synapse/storage/databases/main/user_directory.py | 4 +- synapse/storage/databases/state/bg_updates.py | 2 +- synapse/storage/databases/state/store.py | 2 +- synapse/util/manhole.py | 2 +- synapse/util/retryutils.py | 2 +- tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_e2e_room_keys.py | 2 +- tests/replication/slave/storage/test_events.py | 2 +- tests/rest/test_well_known.py | 2 +- tests/server.py | 2 +- tests/storage/test_appservice.py | 2 +- tests/storage/test_devices.py | 2 +- tests/test_state.py | 2 +- tests/unittest.py | 2 +- 133 files changed, 272 insertions(+), 281 deletions(-) create mode 100644 changelog.d/8344.misc (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8344.misc b/changelog.d/8344.misc new file mode 100644 index 0000000000..0b342d5137 --- /dev/null +++ b/changelog.d/8344.misc @@ -0,0 +1 @@ +Simplify `super()` calls to Python 3 syntax. diff --git a/scripts-dev/definitions.py b/scripts-dev/definitions.py index 15e6ce6e16..313860df13 100755 --- a/scripts-dev/definitions.py +++ b/scripts-dev/definitions.py @@ -11,7 +11,7 @@ import yaml class DefinitionVisitor(ast.NodeVisitor): def __init__(self): - super(DefinitionVisitor, self).__init__() + super().__init__() self.functions = {} self.classes = {} self.names = {} diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 848a826f17..abcec48c4f 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -321,7 +321,7 @@ class MatrixConnectionAdapter(HTTPAdapter): url = urlparse.urlunparse( ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment) ) - return super(MatrixConnectionAdapter, self).get_connection(url, proxies) + return super().get_connection(url, proxies) if __name__ == "__main__": diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 94a9e58eae..cd6670d0a2 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -87,7 +87,7 @@ class CodeMessageException(RuntimeError): """ def __init__(self, code: Union[int, HTTPStatus], msg: str): - super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) + super().__init__("%d: %s" % (code, msg)) # Some calls to this method pass instances of http.HTTPStatus for `code`. # While HTTPStatus is a subclass of int, it has magic __str__ methods @@ -138,7 +138,7 @@ class SynapseError(CodeMessageException): msg: The human-readable error message. errcode: The matrix error code e.g 'M_FORBIDDEN' """ - super(SynapseError, self).__init__(code, msg) + super().__init__(code, msg) self.errcode = errcode def error_dict(self): @@ -159,7 +159,7 @@ class ProxiedRequestError(SynapseError): errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, ): - super(ProxiedRequestError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) if additional_fields is None: self._additional_fields = {} # type: Dict else: @@ -181,7 +181,7 @@ class ConsentNotGivenError(SynapseError): msg: The human-readable error message consent_url: The URL where the user can give their consent """ - super(ConsentNotGivenError, self).__init__( + super().__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri @@ -201,7 +201,7 @@ class UserDeactivatedError(SynapseError): Args: msg: The human-readable error message """ - super(UserDeactivatedError, self).__init__( + super().__init__( code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED ) @@ -225,7 +225,7 @@ class FederationDeniedError(SynapseError): self.destination = destination - super(FederationDeniedError, self).__init__( + super().__init__( code=403, msg="Federation denied with %s." % (self.destination,), errcode=Codes.FORBIDDEN, @@ -244,9 +244,7 @@ class InteractiveAuthIncompleteError(Exception): """ def __init__(self, session_id: str, result: "JsonDict"): - super(InteractiveAuthIncompleteError, self).__init__( - "Interactive auth not yet complete" - ) + super().__init__("Interactive auth not yet complete") self.session_id = session_id self.result = result @@ -261,14 +259,14 @@ class UnrecognizedRequestError(SynapseError): message = "Unrecognized request" else: message = args[0] - super(UnrecognizedRequestError, self).__init__(400, message, **kwargs) + super().__init__(400, message, **kwargs) class NotFoundError(SynapseError): """An error indicating we can't find the thing you asked for""" def __init__(self, msg: str = "Not found", errcode: str = Codes.NOT_FOUND): - super(NotFoundError, self).__init__(404, msg, errcode=errcode) + super().__init__(404, msg, errcode=errcode) class AuthError(SynapseError): @@ -279,7 +277,7 @@ class AuthError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(AuthError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class InvalidClientCredentialsError(SynapseError): @@ -335,7 +333,7 @@ class ResourceLimitError(SynapseError): ): self.admin_contact = admin_contact self.limit_type = limit_type - super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) + super().__init__(code, msg, errcode=errcode) def error_dict(self): return cs_error( @@ -352,7 +350,7 @@ class EventSizeError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.TOO_LARGE - super(EventSizeError, self).__init__(413, *args, **kwargs) + super().__init__(413, *args, **kwargs) class EventStreamError(SynapseError): @@ -361,7 +359,7 @@ class EventStreamError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.BAD_PAGINATION - super(EventStreamError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class LoginError(SynapseError): @@ -384,7 +382,7 @@ class InvalidCaptchaError(SynapseError): error_url: Optional[str] = None, errcode: str = Codes.CAPTCHA_INVALID, ): - super(InvalidCaptchaError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) self.error_url = error_url def error_dict(self): @@ -402,7 +400,7 @@ class LimitExceededError(SynapseError): retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super(LimitExceededError, self).__init__(code, msg, errcode) + super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms def error_dict(self): @@ -418,9 +416,7 @@ class RoomKeysVersionError(SynapseError): Args: current_version: the current version of the store they should have used """ - super(RoomKeysVersionError, self).__init__( - 403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION - ) + super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION) self.current_version = current_version @@ -429,7 +425,7 @@ class UnsupportedRoomVersionError(SynapseError): not support.""" def __init__(self, msg: str = "Homeserver does not support this room version"): - super(UnsupportedRoomVersionError, self).__init__( + super().__init__( code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, ) @@ -440,7 +436,7 @@ class ThreepidValidationError(SynapseError): def __init__(self, *args, **kwargs): if "errcode" not in kwargs: kwargs["errcode"] = Codes.FORBIDDEN - super(ThreepidValidationError, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class IncompatibleRoomVersionError(SynapseError): @@ -451,7 +447,7 @@ class IncompatibleRoomVersionError(SynapseError): """ def __init__(self, room_version: str): - super(IncompatibleRoomVersionError, self).__init__( + super().__init__( code=400, msg="Your homeserver does not support the features required to " "join this room", @@ -473,7 +469,7 @@ class PasswordRefusedError(SynapseError): msg: str = "This password doesn't comply with the server's policy", errcode: str = Codes.WEAK_PASSWORD, ): - super(PasswordRefusedError, self).__init__( + super().__init__( code=400, msg=msg, errcode=errcode, ) @@ -488,7 +484,7 @@ class RequestSendFailed(RuntimeError): """ def __init__(self, inner_exception, can_retry): - super(RequestSendFailed, self).__init__( + super().__init__( "Failed to send request: %s: %s" % (type(inner_exception).__name__, inner_exception) ) @@ -542,7 +538,7 @@ class FederationError(RuntimeError): self.source = source msg = "%s %s: %s" % (level, code, reason) - super(FederationError, self).__init__(msg) + super().__init__(msg) def get_dict(self): return { @@ -570,7 +566,7 @@ class HttpResponseException(CodeMessageException): msg: reason phrase from HTTP response status line response: body of response """ - super(HttpResponseException, self).__init__(code, msg) + super().__init__(code, msg) self.response = response def to_synapse_error(self): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bb33345be6..5caf336fd0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -132,7 +132,7 @@ def matrix_user_id_validator(user_id_str): class Filtering: def __init__(self, hs): - super(Filtering, self).__init__() + super().__init__() self.store = hs.get_datastore() async def get_user_filter(self, user_localpart, filter_id): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f985810e88..c38413c893 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -152,7 +152,7 @@ class PresenceStatusStubServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P[^/]*)/status") def __init__(self, hs): - super(PresenceStatusStubServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -176,7 +176,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.http_client = hs.get_simple_http_client() @@ -646,7 +646,7 @@ class GenericWorkerServer(HomeServer): class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): - super(GenericWorkerReplicationHandler, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index bb6fa8299a..1514c0f691 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -88,7 +88,7 @@ class ApplicationServiceApi(SimpleHttpClient): """ def __init__(self, hs): - super(ApplicationServiceApi, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.protocol_meta_cache = ResponseCache( diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index aec9c4bbce..fbddebeeab 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -77,7 +77,7 @@ class ConsentConfig(Config): section = "consent" def __init__(self, *args): - super(ConsentConfig, self).__init__(*args) + super().__init__(*args) self.user_consent_version = None self.user_consent_template_dir = None diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a185655774..5ffbb934fe 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -30,7 +30,7 @@ class AccountValidityConfig(Config): def __init__(self, config, synapse_config): if config is None: return - super(AccountValidityConfig, self).__init__() + super().__init__() self.enabled = config.get("enabled", False) self.renew_by_email_enabled = "renew_at" in config diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py index 6c427b6f92..57f69dc8e2 100644 --- a/synapse/config/server_notices_config.py +++ b/synapse/config/server_notices_config.py @@ -62,7 +62,7 @@ class ServerNoticesConfig(Config): section = "servernotices" def __init__(self, *args): - super(ServerNoticesConfig, self).__init__(*args) + super().__init__(*args) self.server_notices_mxid = None self.server_notices_mxid_display_name = None self.server_notices_mxid_avatar_url = None diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 32c31b1cd1..42e4087a92 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -558,7 +558,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the "perspectives" servers""" def __init__(self, hs): - super(PerspectivesKeyFetcher, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_http_client() self.key_servers = self.config.key_servers @@ -728,7 +728,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): """KeyFetcher impl which fetches keys from the origin servers""" def __init__(self, hs): - super(ServerKeyFetcher, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() self.client = hs.get_http_client() diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index a2e8d96ea2..639d19f696 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -79,7 +79,7 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): def __init__(self, hs): - super(FederationClient, self).__init__(hs) + super().__init__(hs) self.pdu_destination_tried = {} self._clock.looping_call(self._clear_tried_cache, 60 * 1000) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ff00f0b302..2dcd081cbc 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -90,7 +90,7 @@ pdu_process_time = Histogram( class FederationServer(FederationBase): def __init__(self, hs): - super(FederationServer, self).__init__(hs) + super().__init__(hs) self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cc7e9a973b..3a6b95631e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -68,7 +68,7 @@ class TransportLayerServer(JsonResource): self.clock = hs.get_clock() self.servlet_groups = servlet_groups - super(TransportLayerServer, self).__init__(hs, canonical_json=False) + super().__init__(hs, canonical_json=False) self.authenticator = Authenticator(hs) self.ratelimiter = hs.get_federation_ratelimiter() @@ -376,9 +376,7 @@ class FederationSendServlet(BaseFederationServlet): RATELIMIT = False def __init__(self, handler, server_name, **kwargs): - super(FederationSendServlet, self).__init__( - handler, server_name=server_name, **kwargs - ) + super().__init__(handler, server_name=server_name, **kwargs) self.server_name = server_name # This is when someone is trying to send us a bunch of data. @@ -773,9 +771,7 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): - super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name - ) + super().__init__(handler, authenticator, ratelimiter, server_name) self.allow_access = allow_access async def on_GET(self, origin, content, query): diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 1dd20ee4e1..e5f85b472d 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -336,7 +336,7 @@ class GroupsServerWorkerHandler: class GroupsServerHandler(GroupsServerWorkerHandler): def __init__(self, hs): - super(GroupsServerHandler, self).__init__(hs) + super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 5e5a64037d..dd981c597e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class AdminHandler(BaseHandler): def __init__(self, hs): - super(AdminHandler, self).__init__(hs) + super().__init__(hs) self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4e658d9a48..0322b60cfc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -145,7 +145,7 @@ class AuthHandler(BaseHandler): Args: hs (synapse.server.HomeServer): """ - super(AuthHandler, self).__init__(hs) + super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 25169157c1..0635ad5708 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -29,7 +29,7 @@ class DeactivateAccountHandler(BaseHandler): """Handler which deals with deactivating user accounts.""" def __init__(self, hs): - super(DeactivateAccountHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 4b0a4f96cc..55a9787439 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -48,7 +48,7 @@ MAX_DEVICE_DISPLAY_NAME_LEN = 100 class DeviceWorkerHandler(BaseHandler): def __init__(self, hs): - super(DeviceWorkerHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.state = hs.get_state_handler() @@ -251,7 +251,7 @@ class DeviceWorkerHandler(BaseHandler): class DeviceHandler(DeviceWorkerHandler): def __init__(self, hs): - super(DeviceHandler, self).__init__(hs) + super().__init__(hs) self.federation_sender = hs.get_federation_sender() diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 46826eb784..62aa9a2da8 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class DirectoryHandler(BaseHandler): def __init__(self, hs): - super(DirectoryHandler, self).__init__(hs) + super().__init__(hs) self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index fdce54c5c3..0875b74ea8 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class EventStreamHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(EventStreamHandler, self).__init__(hs) + super().__init__(hs) self.clock = hs.get_clock() @@ -142,7 +142,7 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(EventHandler, self).__init__(hs) + super().__init__(hs) self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 262901363f..96eeff7b1b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -115,7 +115,7 @@ class FederationHandler(BaseHandler): """ def __init__(self, hs): - super(FederationHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 44df567983..9684e60fc8 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -240,7 +240,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): def __init__(self, hs): - super(GroupsLocalHandler, self).__init__(hs) + super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0ce6ddfbe4..ab15570f7a 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -45,7 +45,7 @@ id_server_scheme = "https://" class IdentityHandler(BaseHandler): def __init__(self, hs): - super(IdentityHandler, self).__init__(hs) + super().__init__(hs) self.http_client = SimpleHttpClient(hs) # We create a blacklisting instance of SimpleHttpClient for contacting identity diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index ba4828c713..8cd7eb22a3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) class InitialSyncHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(InitialSyncHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 0cb8fad89a..5453e6dfc8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -44,7 +44,7 @@ class BaseProfileHandler(BaseHandler): """ def __init__(self, hs): - super(BaseProfileHandler, self).__init__(hs) + super().__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -369,7 +369,7 @@ class MasterProfileHandler(BaseProfileHandler): PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs): - super(MasterProfileHandler, self).__init__(hs) + super().__init__(hs) assert hs.config.worker_app is None diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index e3b528d271..c32f314a1c 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) class ReadMarkerHandler(BaseHandler): def __init__(self, hs): - super(ReadMarkerHandler, self).__init__(hs) + super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() self.read_marker_linearizer = Linearizer(name="read_marker") diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index bdd8e52edd..7225923757 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) class ReceiptsHandler(BaseHandler): def __init__(self, hs): - super(ReceiptsHandler, self).__init__(hs) + super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cde2dbca92..538f4b2a61 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -42,7 +42,7 @@ class RegistrationHandler(BaseHandler): Args: hs (synapse.server.HomeServer): """ - super(RegistrationHandler, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index eeade6ad3f..11bf146bed 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -70,7 +70,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000 class RoomCreationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): - super(RoomCreationHandler, self).__init__(hs) + super().__init__(hs) self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 5dd7b28391..4a13c8e912 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -38,7 +38,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs): - super(RoomListHandler, self).__init__(hs) + super().__init__(hs) self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache(hs, "room_list") self.remote_response_cache = ResponseCache( diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index e7f34737c6..f2e88f6a5b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class RoomMemberWorkerHandler(RoomMemberHandler): def __init__(self, hs): - super(RoomMemberWorkerHandler, self).__init__(hs) + super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index d58f9788c5..6a76c20d79 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): def __init__(self, hs): - super(SearchHandler, self).__init__(hs) + super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 4d245b618b..a5d67f828f 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -27,7 +27,7 @@ class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" def __init__(self, hs): - super(SetPasswordHandler, self).__init__(hs) + super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index e21f8dbc58..79393c8829 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -37,7 +37,7 @@ class UserDirectoryHandler(StateDeltasHandler): """ def __init__(self, hs): - super(UserDirectoryHandler, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.state = hs.get_state_handler() diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index 3880ce0d94..8eb3638591 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -27,7 +27,7 @@ class RequestTimedOutError(SynapseError): """Exception representing timeout of an outbound request""" def __init__(self): - super(RequestTimedOutError, self).__init__(504, "Timed out") + super().__init__(504, "Timed out") def cancelled_to_request_timed_out_error(value, timeout): diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index d736ad5b9b..11f60a77f7 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -30,7 +30,7 @@ class LogFormatter(logging.Formatter): """ def __init__(self, *args, **kwargs): - super(LogFormatter, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def formatException(self, ei): sio = StringIO() diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 026854b4c7..7b9c657456 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -107,7 +107,7 @@ class _LogContextScope(Scope): finish_on_close (Boolean): if True finish the span when the scope is closed """ - super(_LogContextScope, self).__init__(manager, span) + super().__init__(manager, span) self.logcontext = logcontext self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext @@ -120,9 +120,9 @@ class _LogContextScope(Scope): def __exit__(self, type, value, traceback): if type == twisted.internet.defer._DefGen_Return: - super(_LogContextScope, self).__exit__(None, None, None) + super().__exit__(None, None, None) else: - super(_LogContextScope, self).__exit__(type, value, traceback) + super().__exit__(type, value, traceback) if self._enter_logcontext: self.logcontext.__exit__(type, value, traceback) else: # the logcontext existed before the creation of the scope diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index edf45dc599..5a437f9810 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -16,4 +16,4 @@ class PusherConfigException(Exception): def __init__(self, msg): - super(PusherConfigException, self).__init__(msg) + super().__init__(msg) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 20f3ba76c0..807b85d2e1 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -53,7 +53,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): CACHE = False def __init__(self, hs): - super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs) + super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater self.store = hs.get_datastore() diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5c8be747e1..5393b9a9e7 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -57,7 +57,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): PATH_ARGS = () def __init__(self, hs): - super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.storage = hs.get_storage() @@ -150,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): PATH_ARGS = ("edu_type",) def __init__(self, hs): - super(ReplicationFederationSendEduRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -193,7 +193,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): CACHE = False def __init__(self, hs): - super(ReplicationGetQueryRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -236,7 +236,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id",) def __init__(self, hs): - super(ReplicationCleanRoomRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index fb326bb869..4c81e2d784 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -32,7 +32,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(RegisterDeviceReplicationServlet, self).__init__(hs) + super().__init__(hs) self.registration_handler = hs.get_registration_handler() @staticmethod diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 08095fdf7d..30680baee8 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id", "user_id") def __init__(self, hs): - super(ReplicationRemoteJoinRestServlet, self).__init__(hs) + super().__init__(hs) self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() @@ -107,7 +107,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): PATH_ARGS = ("invite_event_id",) def __init__(self, hs: "HomeServer"): - super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -168,7 +168,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): CACHE = False # No point caching as should return instantly. def __init__(self, hs): - super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__(hs) + super().__init__(hs) self.registeration_handler = hs.get_registration_handler() self.store = hs.get_datastore() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index a02b27474d..7b12ec9060 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -29,7 +29,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(ReplicationRegisterServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -104,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) def __init__(self, hs): - super(ReplicationPostRegisterActionsServlet, self).__init__(hs) + super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index f13d452426..9a3a694d5d 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -52,7 +52,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): PATH_ARGS = ("event_id",) def __init__(self, hs): - super(ReplicationSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 60f2e1245f..d25fa49e1a 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(BaseSlavedStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( db_conn, diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bb66ba9b80..4268565fc8 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -34,7 +34,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved ], ) - super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index a6fdedde63..1f8dafe7ea 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -22,7 +22,7 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedClientIpStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 533d927701..5b045bed02 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_inbox", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 3b788c9625..e0d86240dd 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -24,7 +24,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedDeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index da1cc836cf..fbffe6d85c 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -56,7 +56,7 @@ class SlavedEventStore( BaseSlavedStore, ): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedEventStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 2562b6fc38..6a23252861 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -21,7 +21,7 @@ from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedFilteringStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 567b4a5cc1..30955bcbfe 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 025f6f6be8..55620c03d8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -25,7 +25,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedPresenceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 9da218bfe8..c418730ba8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SlavedPusherStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 5c2986e050..6195917376 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -30,7 +30,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 80ae803ad9..109ac6bea1 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -23,7 +23,7 @@ from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 1f609f158c..54dccd15a6 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -345,7 +345,7 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() - super(PushRulesStream, self).__init__( + super().__init__( hs.get_instance_name(), self._current_token, self.store.get_all_push_rule_updates, diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 4670d7160d..a163863322 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -36,7 +36,7 @@ class DeviceRestServlet(RestServlet): ) def __init__(self, hs): - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index b210015173..faabeeb91c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet): PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet): PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) def __init__(self, hs): - super(ClientAppserviceDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 25effd0261..985d994f6b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -30,7 +30,7 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 def __init__(self, hs): - super(EventStreamRestServlet, self).__init__() + super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() @@ -74,7 +74,7 @@ class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) def __init__(self, hs): - super(EventRestServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 910b3b4eeb..d7042786ce 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -24,7 +24,7 @@ class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) def __init__(self, hs): - super(InitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index dd8cdc0d9f..250b03a025 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -48,7 +48,7 @@ class LoginRestServlet(RestServlet): APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" def __init__(self, hs): - super(LoginRestServlet, self).__init__() + super().__init__() self.hs = hs # JWT configuration variables. @@ -429,7 +429,7 @@ class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) def __init__(self, hs): - super(CasTicketServlet, self).__init__() + super().__init__() self._cas_handler = hs.get_cas_handler() async def on_GET(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b0c30b65be..f792b50cdc 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) def __init__(self, hs): - super(LogoutRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) def __init__(self, hs): - super(LogoutAllRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 970fdd5834..79d8e3057f 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) def __init__(self, hs): - super(PresenceStatusRestServlet, self).__init__() + super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e7fe50ed72..b686cd671f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -25,7 +25,7 @@ class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) def __init__(self, hs): - super(ProfileDisplaynameRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() @@ -73,7 +73,7 @@ class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) def __init__(self, hs): - super(ProfileAvatarURLRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() @@ -124,7 +124,7 @@ class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) def __init__(self, hs): - super(ProfileRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index ddf8ed5e9c..f9eecb7cf5 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -38,7 +38,7 @@ class PushRuleRestServlet(RestServlet): ) def __init__(self, hs): - super(PushRuleRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5f65cb7d83..28dabf1c7a 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) def __init__(self, hs): - super(PushersRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) def __init__(self, hs): - super(PushersSetRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet): SUCCESS_HTML = b"You have been unsubscribed" def __init__(self, hs): - super(PushersRemoveRestServlet, self).__init__() + super().__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 84baf3d59b..7e64a2e0fe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -57,7 +57,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): def __init__(self, hs): - super(TransactionRestServlet, self).__init__() + super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here def __init__(self, hs): - super(RoomCreateRestServlet, self).__init__(hs) + super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() @@ -111,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomStateEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() @@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -280,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): - super(JoinRoomAliasServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -343,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) def __init__(self, hs): - super(PublicRoomListRestServlet, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -448,7 +448,7 @@ class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) def __init__(self, hs): - super(RoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -499,7 +499,7 @@ class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) def __init__(self, hs): - super(JoinedRoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -518,7 +518,7 @@ class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) def __init__(self, hs): - super(RoomMessageListRestServlet, self).__init__() + super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() @@ -557,7 +557,7 @@ class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) def __init__(self, hs): - super(RoomStateRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -577,7 +577,7 @@ class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) def __init__(self, hs): - super(RoomInitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() @@ -596,7 +596,7 @@ class RoomEventServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() @@ -628,7 +628,7 @@ class RoomEventContextServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventContextServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -675,7 +675,7 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomForgetRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -701,7 +701,7 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomMembershipRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -792,7 +792,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomRedactEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -841,7 +841,7 @@ class RoomTypingRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomTypingRestServlet, self).__init__() + super().__init__() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() @@ -914,7 +914,7 @@ class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) def __init__(self, hs): - super(SearchRestServlet, self).__init__() + super().__init__() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -935,7 +935,7 @@ class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) def __init__(self, hs): - super(JoinedRoomsRestServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 50277c6cf6..b8d491ca5c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet): PATTERNS = client_patterns("/voip/turnServer$", v1=True) def __init__(self, hs): - super(VoipRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ade97a6708..c3ce0f6259 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -52,7 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") def __init__(self, hs): - super(EmailPasswordRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.datastore = hs.get_datastore() self.config = hs.config @@ -156,7 +156,7 @@ class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") def __init__(self, hs): - super(PasswordRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -282,7 +282,7 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") def __init__(self, hs): - super(DeactivateAccountRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -330,7 +330,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): - super(EmailThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.identity_handler = hs.get_handlers().identity_handler @@ -427,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs): self.hs = hs - super(MsisdnThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler @@ -606,7 +606,7 @@ class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") def __init__(self, hs): - super(ThreepidRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -662,7 +662,7 @@ class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") def __init__(self, hs): - super(ThreepidAddRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -713,7 +713,7 @@ class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") def __init__(self, hs): - super(ThreepidBindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -742,7 +742,7 @@ class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") def __init__(self, hs): - super(ThreepidUnbindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -773,7 +773,7 @@ class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") def __init__(self, hs): - super(ThreepidDeleteRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -852,7 +852,7 @@ class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") def __init__(self, hs): - super(WhoamiRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index c1d4cd0caf..87a5b1b86b 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -34,7 +34,7 @@ class AccountDataServlet(RestServlet): ) def __init__(self, hs): - super(AccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -86,7 +86,7 @@ class RoomAccountDataServlet(RestServlet): ) def __init__(self, hs): - super(RoomAccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index d06336ceea..bd7f9ae203 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValidityRenewServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() @@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValiditySendMailServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8e585e9153..097538f968 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -124,7 +124,7 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): - super(AuthRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index fe9d019c44..76879ac559 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(CapabilitiesRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index c0714fcfb1..7e174de692 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): - super(DeleteDevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index b28da017cd..7cc692643b 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") def __init__(self, hs): - super(GetFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() @@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter") def __init__(self, hs): - super(CreateFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 13ecf7005d..a3bb095c2d 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -32,7 +32,7 @@ class GroupServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") def __init__(self, hs): - super(GroupServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -66,7 +66,7 @@ class GroupSummaryServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") def __init__(self, hs): - super(GroupSummaryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -97,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryRoomsCatServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -137,7 +137,7 @@ class GroupCategoryServlet(RestServlet): ) def __init__(self, hs): - super(GroupCategoryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -181,7 +181,7 @@ class GroupCategoriesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") def __init__(self, hs): - super(GroupCategoriesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -204,7 +204,7 @@ class GroupRoleServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") def __init__(self, hs): - super(GroupRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -248,7 +248,7 @@ class GroupRolesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") def __init__(self, hs): - super(GroupRolesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -279,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryUsersRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -317,7 +317,7 @@ class GroupRoomServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") def __init__(self, hs): - super(GroupRoomServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -343,7 +343,7 @@ class GroupUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") def __init__(self, hs): - super(GroupUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -366,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") def __init__(self, hs): - super(GroupInvitedUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -389,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") def __init__(self, hs): - super(GroupSettingJoinPolicyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @@ -413,7 +413,7 @@ class GroupCreateServlet(RestServlet): PATTERNS = client_patterns("/create_group$") def __init__(self, hs): - super(GroupCreateServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -444,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -481,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsConfigServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -507,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -536,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersKickServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -560,7 +560,7 @@ class GroupSelfLeaveServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") def __init__(self, hs): - super(GroupSelfLeaveServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -584,7 +584,7 @@ class GroupSelfJoinServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") def __init__(self, hs): - super(GroupSelfJoinServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -608,7 +608,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") def __init__(self, hs): - super(GroupSelfAcceptInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -632,7 +632,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") def __init__(self, hs): - super(GroupSelfUpdatePublicityServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -655,7 +655,7 @@ class PublicisedGroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") def __init__(self, hs): - super(PublicisedGroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -676,7 +676,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): - super(PublicisedGroupsForUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -700,7 +700,7 @@ class GroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): - super(GroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24bb090822..7abd6ff333 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyQueryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -177,7 +177,7 @@ class KeyChangesServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyChangesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -222,7 +222,7 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): - super(OneTimeKeyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -250,7 +250,7 @@ class SigningKeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SigningKeyUploadServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -308,7 +308,7 @@ class SignaturesUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SignaturesUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index aa911d75ee..87063ec8b1 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") def __init__(self, hs): - super(NotificationsServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index 6ae9a5a8e9..5b996e2d63 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 def __init__(self, hs): - super(IdTokenServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py index 968403cca4..68b27ff23a 100644 --- a/synapse/rest/client/v2_alpha/password_policy.py +++ b/synapse/rest/client/v2_alpha/password_policy.py @@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(PasswordPolicyServlet, self).__init__() + super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index 67cbc37312..55c6688f52 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") def __init__(self, hs): - super(ReadMarkerRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 92555bd4a9..6f7246a394 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet): ) def __init__(self, hs): - super(ReceiptRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 0705718d00..ffa2dfce42 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -76,7 +76,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(EmailRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.config = hs.config @@ -174,7 +174,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(MsisdnRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -249,7 +249,7 @@ class RegistrationSubmitTokenServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegistrationSubmitTokenServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.config = hs.config @@ -319,7 +319,7 @@ class UsernameAvailabilityRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UsernameAvailabilityRestServlet, self).__init__() + super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( @@ -363,7 +363,7 @@ class RegisterRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegisterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index e29f49f7f5..18c75738f8 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -61,7 +61,7 @@ class RelationSendServlet(RestServlet): ) def __init__(self, hs): - super(RelationSendServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) @@ -138,7 +138,7 @@ class RelationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -233,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() @@ -311,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationGroupPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e15927c4ea..215d619ca1 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") def __init__(self, hs): - super(ReportEventRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 59529707df..53de97923f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysNewVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index 39a5518614..bf030e0ff4 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -53,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomUpgradeRestServlet, self).__init__() + super().__init__() self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index db829f3098..bc4f43639a 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SendToDeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py index 2492634dac..c866d5151c 100644 --- a/synapse/rest/client/v2_alpha/shared_rooms.py +++ b/synapse/rest/client/v2_alpha/shared_rooms.py @@ -34,7 +34,7 @@ class UserSharedRoomsServlet(RestServlet): ) def __init__(self, hs): - super(UserSharedRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.user_directory_active = hs.config.update_user_directory diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a0b00135e1..51e395cc64 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -74,7 +74,7 @@ class SyncRestServlet(RestServlet): ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): - super(SyncRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.sync_handler = hs.get_sync_handler() diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index a3f12e8a77..bf3a79db44 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -31,7 +31,7 @@ class TagListServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") def __init__(self, hs): - super(TagListServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -56,7 +56,7 @@ class TagServlet(RestServlet): ) def __init__(self, hs): - super(TagServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 23709960ad..0c127a1b5f 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocols") def __init__(self, hs): - super(ThirdPartyProtocolsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") def __init__(self, hs): - super(ThirdPartyProtocolServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") def __init__(self, hs): - super(ThirdPartyUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") def __init__(self, hs): - super(ThirdPartyLocationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 83f3b6b70a..79317c74ba 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet): PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): - super(TokenRefreshRestServlet, self).__init__() + super().__init__() async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index bef91a2d3e..ad598cefe0 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -31,7 +31,7 @@ class UserDirectorySearchRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UserDirectorySearchRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 24ac57f35d..d5018afbda 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -28,7 +28,7 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def __init__(self, hs): - super(VersionsRestServlet, self).__init__() + super().__init__() self.config = hs.config def on_GET(self, request): diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 2ae2fbd5d7..ccb3384db9 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -172,7 +172,7 @@ class DataStore( else: self._cache_id_gen = None - super(DataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 5f1a2b9aa6..c5a36990e4 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -42,7 +42,7 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -313,7 +313,7 @@ class AccountDataStore(AccountDataWorkerStore): ], ) - super(AccountDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_account_data_stream_id(self) -> int: """Get the current max stream id for the private user data stream diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 454c0bc50c..85f6b1e3fd 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c2fc847fbc..239c7a949c 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "user_ips_device_index", @@ -358,7 +358,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): name="client_ip_last_seen", keylen=4, max_entries=50000 ) - super(ClientIpStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 0044433110..e71217a41f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_inbox_stream_index", @@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceInboxStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 306fc6947c..c04374e43d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -701,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( "device_lists_stream_idx", @@ -826,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(DeviceStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4c3c162acf..6d3689c09e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventFederationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7805fb814e..62f1738732 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -68,7 +68,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -661,7 +661,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventPushActionsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index e53c6373a8..5e4af2eb51 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cd3739c16c..de9e8d1dc6 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -75,7 +75,7 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(EventsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 1d76c761a6..cc538c5c10 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -24,9 +24,7 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__( - database, db_conn, hs - ) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( update_name="local_media_repository_url_idx", @@ -94,7 +92,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" def __init__(self, database: DatabasePool, db_conn, hs): - super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 1d793d3deb..e0cedd1aac 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -28,7 +28,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -120,7 +120,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_stats_only = hs.config.mau_stats_only diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index b7a8d34ce1..e20a16f907 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -77,7 +77,7 @@ class PushRulesWorkerStore( """ def __init__(self, database: DatabasePool, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: self._push_rules_stream_id_gen = StreamIdGenerator( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6568bddd81..f880b5e562 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -39,7 +39,7 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -386,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore): db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 01f20c03c2..675e81fe34 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @@ -764,7 +764,7 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config @@ -892,7 +892,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RegistrationStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 127588ce4c..bd6f9553c6 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -69,7 +69,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -863,7 +863,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config @@ -1074,7 +1074,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 91a8b43da3..4fa8767b01 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -55,7 +55,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -819,7 +819,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) @@ -973,7 +973,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(RoomMemberStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f01cf2fd02..e34fce6281 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) if not hs.config.enable_search: return @@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(SearchStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 5c6168e301..3c1e33819b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: """Get the room_version of a given room @@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" def __init__(self, database: DatabasePool, db_conn, hs): - super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 30840dbbaa..d7816a8606 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -61,7 +61,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(StatsStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 7dbe11513b..5dac78e574 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -266,7 +266,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): """ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): - super(StreamWorkerStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() self._send_federation = hs.should_send_federation() diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 091367006e..99cffff50c 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -48,7 +48,7 @@ class TransactionStore(SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(TransactionStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f2f9a5799a..5a390ff2f6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SHARE_PRIVATE_WORKING_SET = 500 def __init__(self, database: DatabasePool, db_conn, hs): - super(UserDirectoryStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 139085b672..acb24e33af 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" def __init__(self, database: DatabasePool, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e924f1ca3b..bec3780a32 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): """ def __init__(self, database: DatabasePool, db_conn, hs): - super(StateGroupDataStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 631654f297..da24ba0470 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -94,7 +94,7 @@ class SynapseManhole(ColoredManhole): """Overrides connectionMade to create our own ManholeInterpreter""" def connectionMade(self): - super(SynapseManhole, self).connectionMade() + super().connectionMade() # replace the manhole interpreter with our own impl self.interpreter = SynapseManholeInterpreter(self, self.namespace) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 79869aaa44..a5cc9d0551 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -45,7 +45,7 @@ class NotRetryingDestination(Exception): """ msg = "Not retrying server %s." % (destination,) - super(NotRetryingDestination, self).__init__(msg) + super().__init__(msg) self.retry_last_ts = retry_last_ts self.retry_interval = retry_interval diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 210ddcbb88..366dcfb670 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -30,7 +30,7 @@ from tests import unittest, utils class E2eKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): - super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 3362050ce0..7adde9b9de 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -47,7 +47,7 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): - super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 561258a356..bc578411d6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -58,7 +58,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] - return super(SlavedEventStoreTestCase, self).setUp() + return super().setUp() def prepare(self, *args, **kwargs): super().prepare(*args, **kwargs) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index b090bb974c..dcd65c2a50 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -21,7 +21,7 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): def setUp(self): - super(WellKnownTests, self).setUp() + super().setUp() # replace the JsonResource with a WellKnownResource self.resource = WellKnownResource(self.hs) diff --git a/tests/server.py b/tests/server.py index 61ec670155..b404ad4e2a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -260,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) - super(ThreadedMemoryReactorClock, self).__init__() + super().__init__() def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cb808d4de4..46f94914ff 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -413,7 +413,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(TestTransactionStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 34ae8c9da7..ecb00f4e02 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -23,7 +23,7 @@ import tests.utils class DeviceStoreTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): - super(DeviceStoreTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore @defer.inlineCallbacks diff --git a/tests/test_state.py b/tests/test_state.py index 2d58467932..80b0ccbc40 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -125,7 +125,7 @@ class StateGroupStore: class DictObj(dict): def __init__(self, **kwargs): - super(DictObj, self).__init__(kwargs) + super().__init__(kwargs) self.__dict__ = self diff --git a/tests/unittest.py b/tests/unittest.py index 128dd4e19c..dabf69cff4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -92,7 +92,7 @@ class TestCase(unittest.TestCase): root logger's logging level while that test (case|method) runs.""" def __init__(self, methodName, *args, **kwargs): - super(TestCase, self).__init__(methodName, *args, **kwargs) + super().__init__(methodName, *args, **kwargs) method = getattr(self, methodName) -- cgit 1.5.1 From f112cfe5bb2c918c9e942941686a05664d8bd7da Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 24 Sep 2020 16:53:51 +0100 Subject: Fix MultiWriteIdGenerator's handling of restarts. (#8374) On startup `MultiWriteIdGenerator` fetches the maximum stream ID for each instance from the table and uses that as its initial "current position" for each writer. This is problematic as a) it involves either a scan of events table or an index (neither of which is ideal), and b) if rows are being persisted out of order elsewhere while the process restarts then using the maximum stream ID is not correct. This could theoretically lead to race conditions where e.g. events that are persisted out of order are not sent down sync streams. We fix this by creating a new table that tracks the current positions of each writer to the stream, and update it each time we finish persisting a new entry. This is a relatively small overhead when persisting events. However for the cache invalidation stream this is a much bigger relative overhead, so instead we note that for invalidation we don't actually care about reliability over restarts (as there's no caches to invalidate) and simply don't bother reading and writing to the new table in that particular case. --- changelog.d/8374.bugfix | 1 + synapse/replication/slave/storage/_base.py | 2 + synapse/storage/databases/main/__init__.py | 8 +- synapse/storage/databases/main/events_worker.py | 4 + .../main/schema/delta/58/18stream_positions.sql | 22 +++ synapse/storage/util/id_generators.py | 148 ++++++++++++++++++--- tests/storage/test_id_generators.py | 119 +++++++++++++++-- 7 files changed, 274 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8374.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/58/18stream_positions.sql (limited to 'synapse/storage/databases/main/events_worker.py') diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix new file mode 100644 index 0000000000..155bc3404f --- /dev/null +++ b/changelog.d/8374.bugfix @@ -0,0 +1 @@ +Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d25fa49e1a..d0089fe06c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index ccb3384db9..0cb12f4c61 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,14 +160,20 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index de9e8d1dc6..f95679ebc4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="events", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="backfill", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_backfill_stream_seq", positive=False, + writers=hs.config.worker.writers.events, ) else: # We shouldn't be running in worker mode with SQLite, but its useful diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b0353ac2dc..727fcc521c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Union import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -184,12 +185,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +203,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,9 +225,7 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). @@ -251,30 +258,80 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + sql = self._db.engine.convert_param_style(sql) - cur.close() + cur.execute(sql, (self._stream_name,)) + + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } + + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + sql = """ + SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id,)) + + self._persisted_upto_position = min_stream_id + + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) - return current_positions + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -316,6 +373,21 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): @@ -447,6 +519,28 @@ class MultiWriterIdGenerator: # do. break + def _update_stream_positions_table_txn(self, txn): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + @attr.s(slots=True) class _AsyncCtxManagerWrapper: @@ -503,4 +597,16 @@ class _MultiWriterCtxManager: if exc_type is not None: return False + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + ) + return False diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index fb8f5bc255..d4ff55fbff 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, ) return self.get_success(self.db_pool.runWithConnection(_create)) @@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", (instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) @@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, stream_id, stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) @@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_rows("first", 3) self._insert_rows("second", 4) - first_id_gen = self._create_id_generator("first") - second_id_gen = self._create_id_generator("second") + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) @@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + def test_restart_during_out_of_order_persistence(self): + """Test that restarting a process while another process is writing out + of order updates are handled correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # Persist two rows at once + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) + + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") + + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + + # Now if we persist the first row then both instances should jump ahead + # correctly. + self.get_success(ctx1.__aexit__(None, None, None)) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + + def test_writer_config_change(self): + """Test that changing the writer config correctly works. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + # Initial config has two writers + id_gen = self._create_id_generator("first", writers=["first", "second"]) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # New config removes one of the configs. Note that if the writer is + # removed from config we assume that it has been shut down and has + # finished persisting, hence why the persisted upto position is 5. + id_gen_2 = self._create_id_generator("second", writers=["second"]) + self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + + # This config points to a single, previously unused writer. + id_gen_3 = self._create_id_generator("third", writers=["third"]) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + + # Check that we get a sane next stream ID with this new config. + + async def _get_next_async(): + async with id_gen_3.get_next() as stream_id: + self.assertEqual(stream_id, 6) + + self.get_success(_get_next_async()) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. @@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, positive=False, ) @@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): txn.execute( "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, -stream_id, -stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) @@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests that having multiple instances that get advanced over federation works corretly. """ - id_gen_1 = self._create_id_generator("first") - id_gen_2 = self._create_id_generator("second") + id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) + id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) async def _get_next_async(): async with id_gen_1.get_next() as stream_id: -- cgit 1.5.1