From fe8ed1b46f781faa45d1bba8f9308cf47c42010f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Sep 2020 12:22:55 +0100 Subject: Make `StreamToken.room_key` be a `RoomStreamToken` instance. (#8281) --- synapse/storage/databases/main/events.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'synapse/storage/databases/main/events.py') diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b3d27a2ee7..9cd1403b38 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -213,7 +213,7 @@ class PersistEventsStore: Returns: Filtered event ids """ - results = [] + results = [] # type: List[str] def _get_events_which_are_prevs_txn(txn, batch): sql = """ @@ -631,7 +631,9 @@ class PersistEventsStore: ) @classmethod - def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + def _filter_events_and_contexts_for_duplicates( + cls, events_and_contexts: List[Tuple[EventBase, EventContext]] + ) -> List[Tuple[EventBase, EventContext]]: """Ensure that we don't have the same event twice. Pick the earliest non-outlier if there is one, else the earliest one. @@ -641,7 +643,9 @@ class PersistEventsStore: Returns: list[(EventBase, EventContext)]: filtered list """ - new_events_and_contexts = OrderedDict() + new_events_and_contexts = ( + OrderedDict() + ) # type: OrderedDict[str, Tuple[EventBase, EventContext]] for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) if prev_event_context: @@ -655,7 +659,12 @@ class PersistEventsStore: new_events_and_contexts[event.event_id] = (event, context) return list(new_events_and_contexts.values()) - def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + def _update_room_depths_txn( + self, + txn, + events_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + ): """Update min_depth for each room Args: @@ -664,7 +673,7 @@ class PersistEventsStore: we are persisting backfilled (bool): True if the events were backfilled """ - depth_updates = {} + depth_updates = {} # type: Dict[str, int] for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids txn.call_after(self.store._invalidate_get_event_cache, event.event_id) @@ -1436,7 +1445,7 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) -- cgit 1.5.1 From 04cc249b43e8716513f788b2a4eeb8ede24d19df Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 14 Sep 2020 10:16:41 +0100 Subject: Add experimental support for sharding event persister. Again. (#8294) This is *not* ready for production yet. Caveats: 1. We should write some tests... 2. The stream token that we use for events can get stalled at the minimum position of all writers. This means that new events may not be processed and e.g. sent down sync streams if a writer isn't writing or is slow. --- changelog.d/8294.feature | 1 + synapse/config/_base.py | 21 ++++++- synapse/config/_base.pyi | 1 + synapse/config/workers.py | 37 ++++++++---- synapse/handlers/federation.py | 44 ++++++++++----- synapse/handlers/message.py | 14 +++-- synapse/handlers/room.py | 14 +++-- synapse/handlers/room_member.py | 7 --- synapse/replication/http/federation.py | 12 +++- synapse/replication/tcp/handler.py | 2 +- synapse/replication/tcp/streams/events.py | 4 +- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 2 +- synapse/storage/databases/main/events.py | 12 ++-- synapse/storage/databases/main/events_worker.py | 66 +++++++++++++++------- .../schema/delta/58/14events_instance_name.sql | 16 ++++++ .../delta/58/14events_instance_name.sql.postgres | 26 +++++++++ synapse/storage/util/id_generators.py | 10 ++-- 18 files changed, 211 insertions(+), 80 deletions(-) create mode 100644 changelog.d/8294.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql create mode 100644 synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres (limited to 'synapse/storage/databases/main/events.py') diff --git a/changelog.d/8294.feature b/changelog.d/8294.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8294.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ad5ab6ad62..bb9bf8598d 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index eb911e8f9f..b8faafa9bd 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... + def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c784a71508..f23e42cdf9 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c195eba830..a5734bebab 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -896,7 +896,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1189,7 +1190,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1336,15 +1337,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1593,7 +1594,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1620,7 +1621,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1868,7 +1871,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1881,6 +1884,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1912,6 +1916,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1922,6 +1927,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1936,6 +1942,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2010,17 +2017,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2871,6 +2881,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2878,14 +2889,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e54e2b322b..a8fe5cf4e2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -376,9 +376,8 @@ class EventCreationHandler: self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -902,9 +901,10 @@ class EventCreationHandler: try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -972,8 +972,10 @@ class EventCreationHandler: This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer assert self.storage.persistence is not None + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 53d85ab97d..eeade6ad3f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1259,10 +1261,10 @@ class RoomShutdownHandler: # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1292,7 +1294,9 @@ class RoomShutdownHandler: # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 100f335b80..01a6e88262 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,13 +82,6 @@ class RoomMemberHandler: self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1c303f3a46..b323841f73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -109,7 +109,7 @@ class ReplicationCommandHandler: if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f929fc3954..ccc7ca30d8 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 985b12df91..aa5d490624 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -75,7 +75,7 @@ class Databases: # If we're on a process that can persist events also # instantiate a `PersistEventsStore` - if hs.config.worker.writers.events == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.events: persist_events = PersistEventsStore(hs, database, main) if "state" in database_config.databases: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9cd1403b38..9a80f419e3 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -32,7 +32,7 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter @@ -97,18 +97,21 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID # generators are chained off them so doing so is a bit of a PITA. - self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator - self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator + self._backfill_id_gen = ( + self.store._backfill_id_gen + ) # type: MultiWriterIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -809,6 +812,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644 index 0000000000..98ff76d709 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql @@ -0,0 +1,16 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE events ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres new file mode 100644 index 0000000000..97c1e6a0c5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -0,0 +1,26 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS events_stream_seq; + +SELECT setval('events_stream_seq', ( + SELECT COALESCE(MAX(stream_ordering), 1) FROM events +)); + +CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; + +SELECT setval('events_backfill_stream_seq', ( + SELECT COALESCE(-MIN(stream_ordering), 1) FROM events +)); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 2a66b3ad4e..1de2b91587 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -240,8 +240,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -398,9 +402,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. -- cgit 1.5.1 From cbabb312e0b59090e5a8cf9e7e016a8618e62867 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Sep 2020 16:11:18 +0100 Subject: Use `async with` for ID gens (#8383) This will allow us to hit the DB after we've finished using the generated stream ID. --- changelog.d/8383.misc | 1 + synapse/storage/databases/main/account_data.py | 4 +- synapse/storage/databases/main/deviceinbox.py | 4 +- synapse/storage/databases/main/devices.py | 6 +- synapse/storage/databases/main/end_to_end_keys.py | 2 +- synapse/storage/databases/main/events.py | 6 +- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/presence.py | 4 +- synapse/storage/databases/main/push_rule.py | 8 +- synapse/storage/databases/main/pusher.py | 4 +- synapse/storage/databases/main/receipts.py | 2 +- synapse/storage/databases/main/room.py | 6 +- synapse/storage/databases/main/tags.py | 4 +- synapse/storage/util/id_generators.py | 130 +++++++++++++--------- tests/storage/test_id_generators.py | 66 ++++++----- 15 files changed, 144 insertions(+), 105 deletions(-) create mode 100644 changelog.d/8383.misc (limited to 'synapse/storage/databases/main/events.py') diff --git a/changelog.d/8383.misc b/changelog.d/8383.misc new file mode 100644 index 0000000000..cb8318bf57 --- /dev/null +++ b/changelog.d/8383.misc @@ -0,0 +1 @@ +Refactor ID generators to use `async with` syntax. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c5a36990e4..ef81d73573 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index e71217a41f..d42faa3f1f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index c04374e43d..fdf394c612 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3f..22e1ed15d0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a80f419e3..7723d82496 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -156,15 +156,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ccfbb2135e..7218191965 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e20a16f907..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore): Raises: NotFoundError if the rule does not exist. """ - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index f880b5e562..c79ddff680 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3ee097abf7..3c7630857f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1de2b91587..b0353ac2dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque from synapse.storage.database import DatabasePool, LoggingTransaction @@ -86,7 +86,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +101,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +113,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +121,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +140,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +149,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -282,59 +282,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) - - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) - return manager() + return _MultiWriterCtxManager(self) - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -482,3 +446,61 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + return False diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 20636fc400..fb8f5bc255 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await id_gen.get_next() as stream_id: + async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) @@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ctx3 = self.get_success(id_gen.get_next()) ctx4 = self.get_success(id_gen.get_next()) - s1 = ctx1.__enter__() - s2 = ctx2.__enter__() - s3 = ctx3.__enter__() - s4 = ctx4.__enter__() + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + s3 = self.get_success(ctx3.__aenter__()) + s4 = self.get_success(ctx4.__aenter__()) self.assertEqual(s1, 8) self.assertEqual(s2, 9) @@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx2.__exit__(None, None, None) + self.get_success(ctx2.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1.__exit__(None, None, None) + self.get_success(ctx1.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx4.__exit__(None, None, None) + self.get_success(ctx4.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx3.__exit__(None, None, None) + self.get_success(ctx3.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) @@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await first_id_gen.get_next() as stream_id: + async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual( @@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # stream ID async def _get_next_async(): - with await second_id_gen.get_next() as stream_id: + async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) self.assertEqual( @@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 3) - with self.get_success(id_gen.get_next()) as stream_id: - self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_persisted_upto_position(), 6) @@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ id_gen = self._create_id_generator() - with self.get_success(id_gen.get_next()) as stream_id: - self._insert_row("master", stream_id) + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": -1}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - with self.get_success(id_gen.get_next_mult(3)) as stream_ids: - for stream_id in stream_ids: - self._insert_row("master", stream_id) + async def _get_next_async2(): + async with id_gen.get_next_mult(3) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen.get_positions(), {"master": -4}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) @@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_1 = self._create_id_generator("first") id_gen_2 = self._create_id_generator("second") - with self.get_success(id_gen_1.get_next()) as stream_id: - self._insert_row("first", stream_id) - id_gen_2.advance("first", stream_id) + async def _get_next_async(): + async with id_gen_1.get_next() as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - with self.get_success(id_gen_2.get_next()) as stream_id: - self._insert_row("second", stream_id) - id_gen_1.advance("second", stream_id) + async def _get_next_async2(): + async with id_gen_2.get_next() as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) -- cgit 1.5.1 From 302dc89f6a16f69e076943cb0a9b94f1e41741f9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Sep 2020 16:42:14 +0100 Subject: Fix bug which caused failure on join with malformed membership events (#8385) --- changelog.d/8385.bugfix | 1 + synapse/storage/databases/main/events.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8385.bugfix (limited to 'synapse/storage/databases/main/events.py') diff --git a/changelog.d/8385.bugfix b/changelog.d/8385.bugfix new file mode 100644 index 0000000000..c42502a8e0 --- /dev/null +++ b/changelog.d/8385.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7723d82496..18def01f50 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -1108,6 +1108,10 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ + + def str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) else None + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", @@ -1118,8 +1122,8 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), + "display_name": str_or_none(event.content.get("displayname")), + "avatar_url": str_or_none(event.content.get("avatar_url")), } for event in events ], -- cgit 1.5.1