diff options
author | Will Hunt <will@half-shot.uk> | 2020-09-28 11:11:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-28 11:11:57 +0100 |
commit | 6201ea56eeae22258d58db2f7381f96759f5e85e (patch) | |
tree | 37cb630a46fe96e416110de253ad50d256c0482e /synapse/storage | |
parent | Add support for device messages, start support for device lists (diff) | |
parent | typo (diff) | |
download | synapse-6201ea56eeae22258d58db2f7381f96759f5e85e.tar.xz |
Merge branch 'develop' into hs/super-wip-edus-down-sync
Diffstat (limited to 'synapse/storage')
24 files changed, 410 insertions, 121 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index ccb3384db9..0cb12f4c61 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,14 +160,20 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c5a36990e4..ef81d73573 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index e4fd979a33..2d151b9134 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -394,7 +394,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -443,7 +443,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index c04374e43d..fdf394c612 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3f..22e1ed15d0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a80f419e3..18def01f50 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -156,15 +156,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream @@ -1108,6 +1108,10 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ + + def str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) else None + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", @@ -1118,8 +1122,8 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), + "display_name": str_or_none(event.content.get("displayname")), + "avatar_url": str_or_none(event.content.get("avatar_url")), } for event in events ], diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index de9e8d1dc6..f95679ebc4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="events", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="backfill", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_backfill_stream_seq", positive=False, + writers=hs.config.worker.writers.events, ) else: # We shouldn't be running in worker mode with SQLite, but its useful diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ccfbb2135e..7218191965 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e20a16f907..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore): Raises: NotFoundError if the rule does not exist. """ - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 5867d52b62..c10a16ffa3 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -577,7 +577,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 675e81fe34..48ce7ecd16 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_expiration_ts_for_user", ) + async def is_account_expired(self, user_id: str, current_ts: int) -> bool: + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Whether the user account has expired + """ + expiration_ts = await self.get_expiration_ts_for_user(user_id) + return expiration_ts is not None and current_ts >= expiration_ts + async def set_account_validity_for_user( self, user_id: str, @@ -379,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -387,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index bd6f9553c6..3c7630857f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, @@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: str = "b", + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + event_reports: json list of event reports + count: total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn(txn): + filters = [] + args = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.reason, + er.content, + events.sender, + room_aliases.room_alias, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN room_aliases + ON room_aliases.room_id = er.room_id + JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + event_reports = self.db_pool.cursor_to_dict(txn) + + if count > 0: + for row in event_reports: + try: + row["content"] = db_to_json(row["content"]) + row["event_json"] = db_to_json(row["event_json"]) + except Exception: + continue + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4fa8767b01..86ffe2479e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set @@ -37,7 +36,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # for rooms the server is participating in. if self._current_state_events_membership_up_to_date: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ else: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN room_memberships AS m USING (room_id, event_id) INNER JOIN events AS e USING (room_id, event_id) @@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + return frozenset( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + for room_id, instance, stream_id in txn + ) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..ccf287971c 100644 --- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- room_id and topoligical_ordering are denormalised from the events table in order to +-- room_id and topological_ordering are denormalised from the events table in order to -- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres index 97c1e6a0c5..c31f9af82a 100644 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', ( CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; +-- If the server has never backfilled a room then doing `-MIN(...)` will give +-- a negative result, hence why we do `GREATEST(...)` SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events + SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events )); diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7816a8606..5beb302be3 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore): * topic * avatar * canonical_alias + * guest_access A is_federatable key can also be included with a boolean value. @@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore): "topic", "avatar", "canonical_alias", + "guest_access", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 2f7c95fc74..f9575b1f1f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore): return # They are there, delete them. - self.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "erased_users", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index d89f6ed128..603cd7d825 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, StateMap +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -190,6 +190,7 @@ class EventsPersistenceStorage: self.persist_events_store = stores.persist_events self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() @@ -198,7 +199,7 @@ class EventsPersistenceStorage: self, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> RoomStreamToken: """ Write events to the database Args: @@ -228,11 +229,11 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return self.main_store.get_current_events_token() + return RoomStreamToken(None, self.main_store.get_current_events_token()) async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[int, int]: + ) -> Tuple[PersistedEventPosition, RoomStreamToken]: """ Returns: The stream ordering of `event`, and the stream ordering of the @@ -247,7 +248,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) max_persisted_id = self.main_store.get_current_events_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) + event_stream_id = event.internal_metadata.stream_ordering + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return pos, RoomStreamToken(None, max_persisted_id) def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8c4a83a840..f152f63321 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -25,7 +25,7 @@ RoomsForUser = namedtuple( ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") + "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1de2b91587..4269eaf918 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -86,7 +87,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +102,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +114,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +122,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +141,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +150,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -184,12 +185,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +203,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,9 +225,7 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). @@ -251,30 +258,84 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + sql = self._db.engine.convert_param_style(sql) - cur.close() + cur.execute(sql, (self._stream_name,)) - return current_positions + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } + + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id,)) + + self._persisted_upto_position = min_stream_id + + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) + + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -282,59 +343,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) + return _MultiWriterCtxManager(self) - return manager() - - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -352,6 +377,21 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): @@ -482,3 +522,95 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + def _update_stream_positions_table_txn(self, txn): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + ) + + return False |