From bd380d942fdf91cf1214d6859f2bc97d12a92ab4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 28 Sep 2020 18:00:30 +0100 Subject: Add checks for postgres sequence consistency (#8402) --- synapse/storage/databases/main/registration.py | 3 + synapse/storage/databases/state/store.py | 3 + synapse/storage/util/id_generators.py | 5 ++ synapse/storage/util/sequence.py | 90 +++++++++++++++++++++++++- 4 files changed, 99 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 48ce7ecd16..a83df7759d 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -41,6 +41,9 @@ class RegistrationWorkerStore(SQLBaseStore): self.config = hs.config self.clock = hs.get_clock() + # Note: we don't check this sequence for consistency as we'd have to + # call `find_max_generated_user_id_localpart` each time, which is + # expensive if there are many entries. self._user_id_seq = build_sequence_generator( database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index bec3780a32..989f0cbc9d 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_seq_gen = build_sequence_generator( self.database_engine, get_max_state_group_txn, "state_group_id_seq" ) + self._state_group_seq_gen.check_consistency( + db_conn, table="state_groups", id_column="id" + ) @cached(max_entries=10000, iterable=True) async def get_state_group_delta(self, state_group): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4269eaf918..4fd7573e26 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -258,6 +258,11 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # We check that the table and sequence haven't diverged. + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) + # This goes and fills out the above state from the database. self._load_current_ids(db_conn, table, instance_column, id_column) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index ffc1894748..2dd95e2709 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -13,11 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import threading from typing import Callable, List, Optional -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.storage.types import Cursor +from synapse.storage.engines import ( + BaseDatabaseEngine, + IncorrectDatabaseSetup, + PostgresEngine, +) +from synapse.storage.types import Connection, Cursor + +logger = logging.getLogger(__name__) + + +_INCONSISTENT_SEQUENCE_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated +table '%(table)s'. This can happen if Synapse has been downgraded and +then upgraded again, or due to a bad migration. + +To fix this error, shut down Synapse (including any and all workers) +and run the following SQL: + + SELECT setval('%(seq)s', ( + %(max_id_sql)s + )); + +See docs/postgres.md for more information. +""" class SequenceGenerator(metaclass=abc.ABCMeta): @@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + """Should be called during start up to test that the current value of + the sequence is greater than or equal to the maximum ID in the table. + + This is to handle various cases where the sequence value can get out + of sync with the table, e.g. if Synapse gets rolled back to a previous + version and the rolled forwards again. + """ + ... + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator): ) return [i for (i,) in txn] + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + txn = db_conn.cursor() + + # First we get the current max ID from the table. + table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { + "id": id_column, + "table": table, + "agg": "MAX" if positive else "-MIN", + } + + txn.execute(table_sql) + row = txn.fetchone() + if not row: + # Table is empty, so nothing to do. + txn.close() + return + + # Now we fetch the current value from the sequence and compare with the + # above. + max_stream_id = row[0] + txn.execute( + "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} + ) + last_value, is_called = txn.fetchone() + txn.close() + + # If `is_called` is False then `last_value` is actually the value that + # will be generated next, so we decrement to get the true "last value". + if not is_called: + last_value -= 1 + + if max_stream_id > last_value: + logger.warning( + "Postgres sequence %s is behind table %s: %d < %d", + last_value, + max_stream_id, + ) + raise IncorrectDatabaseSetup( + _INCONSISTENT_SEQUENCE_ERROR + % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def check_consistency( + self, db_conn: Connection, table: str, id_column: str, positive: bool = True + ): + # There is nothing to do for in memory sequences + pass + def build_sequence_generator( database_engine: BaseDatabaseEngine, -- cgit 1.5.1 From 8676d8ab2e5667d7c12774effc64b3ab99344a8d Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Tue, 29 Sep 2020 13:11:02 +0100 Subject: Filter out appservices from mau count (#8404) This is an attempt to fix #8403. --- changelog.d/8404.misc | 1 + synapse/storage/databases/main/monthly_active_users.py | 9 ++++++++- tests/storage/test_monthly_active_users.py | 17 ++++++++++++++++- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8404.misc (limited to 'synapse/storage') diff --git a/changelog.d/8404.misc b/changelog.d/8404.misc new file mode 100644 index 0000000000..7aadded6c1 --- /dev/null +++ b/changelog.d/8404.misc @@ -0,0 +1 @@ +Do not include appservice users when calculating the total MAU for a server. diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e0cedd1aac..e93aad33cd 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -41,7 +41,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ def _count_users(txn): - sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + # Exclude app service users + sql = """ + SELECT COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users + ON monthly_active_users.user_id=users.name + WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); + """ txn.execute(sql) (count,) = txn.fetchone() return count diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 643072bbaf..8d97b6d4cd 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -137,6 +137,21 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 1) + def test_appservice_user_not_counted_in_mau(self): + self.get_success( + self.store.register_user( + user_id="@appservice_user:server", appservice_id="wibble" + ) + ) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + d = self.store.upsert_monthly_active_user("@appservice_user:server") + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" @@ -383,7 +398,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, 4) + self.assertEqual(count, 1) d = self.store.get_monthly_active_count_by_service() result = self.get_success(d) -- cgit 1.5.1 From 2649d545a551dd126d73d34a6e3172916ea483e0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Sep 2020 15:57:36 +0100 Subject: Mypy fixes for `synapse.handlers.federation` (#8422) For some reason, an apparently unrelated PR upset mypy about this module. Here are a number of little fixes. --- changelog.d/8422.misc | 1 + synapse/federation/federation_client.py | 4 +++- synapse/handlers/federation.py | 13 +++++++++---- synapse/storage/databases/state/store.py | 4 ++-- synapse/storage/persist_events.py | 2 +- synapse/storage/state.py | 6 +++--- 6 files changed, 19 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8422.misc (limited to 'synapse/storage') diff --git a/changelog.d/8422.misc b/changelog.d/8422.misc new file mode 100644 index 0000000000..03fba120c6 --- /dev/null +++ b/changelog.d/8422.misc @@ -0,0 +1 @@ +Typing fixes for `synapse.handlers.federation`. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 688d43fffb..302b2f69bc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -24,10 +24,12 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Sequence, Tuple, TypeVar, + Union, ) from prometheus_client import Counter @@ -501,7 +503,7 @@ class FederationClient(FederationBase): user_id: str, membership: str, content: dict, - params: Dict[str, str], + params: Optional[Mapping[str, Union[str, Iterable[str]]]], ) -> Tuple[str, EventBase, RoomVersion]: """ Creates an m.room.member event, with context, without participating in the room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5bcfb231b2..0073e7c996 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -155,8 +155,9 @@ class FederationHandler(BaseHandler): self._device_list_updater = hs.get_device_handler().device_list_updater self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite - # When joining a room we need to queue any events for that room up - self.room_queues = {} + # When joining a room we need to queue any events for that room up. + # For each room, a list of (pdu, origin) tuples. + self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]] self._room_pdu_linearizer = Linearizer("fed_room_pdu") self.third_party_event_rules = hs.get_third_party_event_rules() @@ -814,6 +815,9 @@ class FederationHandler(BaseHandler): dest, room_id, limit=limit, extremities=extremities ) + if not events: + return [] + # ideally we'd sanity check the events here for excess prev_events etc, # but it's hard to reject events at this point without completely # breaking backfill in the same way that it is currently broken by @@ -2164,10 +2168,10 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = await self.state_store.get_state_groups( + state_sets_d = await self.state_store.get_state_groups( event.room_id, extrem_ids ) - state_sets = list(state_sets.values()) + state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]] state_sets.append(state) current_states = await self.state_handler.resolve_events( room_version, state_sets, event @@ -2958,6 +2962,7 @@ class FederationHandler(BaseHandler): ) return result["max_stream_id"] else: + assert self.storage.persistence max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 989f0cbc9d..0e31cc811a 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -208,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 603cd7d825..ded6cf9655 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -197,7 +197,7 @@ class EventsPersistenceStorage: async def persist_events( self, - events_and_contexts: List[Tuple[EventBase, EventContext]], + events_and_contexts: Iterable[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> RoomStreamToken: """ diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 8f68d968f0..08a69f2f96 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -20,7 +20,7 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import StateMap +from synapse.types import MutableStateMap, StateMap logger = logging.getLogger(__name__) @@ -349,7 +349,7 @@ class StateGroupStorage: async def get_state_groups_ids( self, _room_id: str, event_ids: Iterable[str] - ) -> Dict[int, StateMap[str]]: + ) -> Dict[int, MutableStateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: @@ -532,7 +532,7 @@ class StateGroupStorage: def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ) -> Awaitable[Dict[int, StateMap[str]]]: + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key -- cgit 1.5.1 From b1433bf231370636b817ffa01e6cda5a567cfafe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Sep 2020 16:42:19 +0100 Subject: Don't table scan events on worker startup (#8419) * Fix table scan of events on worker startup. This happened because we assumed "new" writers had an initial stream position of 0, so the replication code tried to fetch all events written by the instance between 0 and the current position. Instead, set the initial position of new writers to the current persisted up to position, on the assumption that new writers won't have written anything before that point. * Consider old writers coming back as "new". Otherwise we'd try and fetch entries between the old stale token and the current position, even though it won't have written any rows. Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/8419.feature | 1 + synapse/storage/util/id_generators.py | 26 +++++++++++++++++++++++++- tests/storage/test_id_generators.py | 18 ++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8419.feature (limited to 'synapse/storage') diff --git a/changelog.d/8419.feature b/changelog.d/8419.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8419.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 4fd7573e26..02fbb656e8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -273,6 +273,19 @@ class MultiWriterIdGenerator: # Load the current positions of all writers for the stream. if self._writers: + # We delete any stale entries in the positions table. This is + # important if we add back a writer after a long time; we want to + # consider that a "new" writer, rather than using the old stale + # entry here. + sql = """ + DELETE FROM stream_positions + WHERE + stream_name = ? + AND instance_name != ALL(?) + """ + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (self._stream_name, self._writers)) + sql = """ SELECT instance_name, stream_id FROM stream_positions WHERE stream_name = ? @@ -453,11 +466,22 @@ class MultiWriterIdGenerator: """Returns the position of the given writer. """ + # If we don't have an entry for the given instance name, we assume it's a + # new writer. + # + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. + + Note that this won't necessarily include all configured writers if some + writers haven't written anything yet. """ with self._lock: diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 4558bee7be..392b08832b 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -390,17 +390,28 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Initial config has two writers id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(id_gen.get_current_token_for_writer("second"), 5) # New config removes one of the configs. Note that if the writer is # removed from config we assume that it has been shut down and has # finished persisting, hence why the persisted upto position is 5. id_gen_2 = self._create_id_generator("second", writers=["second"]) self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5) # This config points to a single, previously unused writer. id_gen_3 = self._create_id_generator("third", writers=["third"]) self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer comes along. + self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5) + + id_gen_4 = self._create_id_generator("fourth", writers=["third"]) + self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5) + # Check that we get a sane next stream ID with this new config. async def _get_next_async(): @@ -410,6 +421,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async()) self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + # If we add back the old "first" then we shouldn't see the persisted up + # to position revert back to 3. + id_gen_5 = self._create_id_generator("five", writers=["first", "third"]) + self.assertEqual(id_gen_5.get_persisted_upto_position(), 6) + self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) + self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) + def test_sequence_consistency(self): """Test that we error out if the table and sequence diverges. """ -- cgit 1.5.1 From ea70f1c362dc4bd6c0f8a67e16ed0971fe095e5b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Sep 2020 21:48:33 +0100 Subject: Various clean ups to room stream tokens. (#8423) --- changelog.d/8423.misc | 1 + synapse/events/__init__.py | 6 ++-- synapse/handlers/admin.py | 2 +- synapse/handlers/device.py | 4 +-- synapse/handlers/initial_sync.py | 3 +- synapse/handlers/pagination.py | 5 ++- synapse/handlers/room.py | 4 +-- synapse/handlers/sync.py | 20 ++++++++---- synapse/notifier.py | 4 +-- synapse/replication/tcp/client.py | 6 ++-- synapse/rest/admin/__init__.py | 3 +- synapse/storage/databases/main/stream.py | 38 +++++++++++++---------- synapse/storage/persist_events.py | 5 ++- synapse/types.py | 53 ++++++++++++++++++++------------ tests/rest/client/v1/test_rooms.py | 8 ++--- tests/storage/test_purge.py | 10 +++--- 16 files changed, 96 insertions(+), 76 deletions(-) create mode 100644 changelog.d/8423.misc (limited to 'synapse/storage') diff --git a/changelog.d/8423.misc b/changelog.d/8423.misc new file mode 100644 index 0000000000..7260e3fa41 --- /dev/null +++ b/changelog.d/8423.misc @@ -0,0 +1 @@ +Various refactors to simplify stream token handling. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index bf800a3852..dc49df0812 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -118,8 +118,8 @@ class _EventInternalMetadata: # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before = DictProperty("before") # type: str - after = DictProperty("after") # type: str + before = DictProperty("before") # type: RoomStreamToken + after = DictProperty("after") # type: RoomStreamToken order = DictProperty("order") # type: Tuple[int, int] def get_dict(self) -> JsonDict: diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index dd981c597e..1ce2091b46 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -153,7 +153,7 @@ class AdminHandler(BaseHandler): if not events: break - from_key = RoomStreamToken.parse(events[-1].internal_metadata.after) + from_key = events[-1].internal_metadata.after events = await filter_events_for_client(self.storage, user_id, events) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 4149520d6c..b9d9098104 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,7 +29,6 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( - RoomStreamToken, StreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -113,8 +112,7 @@ class DeviceWorkerHandler(BaseHandler): set_tag("user_id", user_id) set_tag("from_token", from_token) - now_room_id = self.store.get_room_max_stream_ordering() - now_room_key = RoomStreamToken(None, now_room_id) + now_room_key = self.store.get_room_max_token() room_ids = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 8cd7eb22a3..43f15435de 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -325,7 +325,8 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = await self.store.get_stream_token_for_event(member_event_id) + leave_position = await self.store.get_position_for_event(member_event_id) + stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a0b3bdb5e0..d6779a4b44 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import Requester, RoomStreamToken +from synapse.types import Requester from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -373,10 +373,9 @@ class PaginationHandler: # case "JOIN" would have been returned. assert member_event_id - leave_token_str = await self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) - leave_token = RoomStreamToken.parse(leave_token_str) assert leave_token.topological is not None if leave_token.topological < curr_topo: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 11bf146bed..836b3f381a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1134,14 +1134,14 @@ class RoomEventSource: events[:] = events[:limit] if events: - end_key = RoomStreamToken.parse(events[-1].internal_metadata.after) + end_key = events[-1].internal_metadata.after else: end_key = to_key return (events, end_key) def get_current_key(self) -> RoomStreamToken: - return RoomStreamToken(None, self.store.get_room_max_stream_ordering()) + return self.store.get_room_max_token() def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: return self.store.get_room_events_max_id(room_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e948efef2e..bfe2583002 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -519,7 +519,7 @@ class SyncHandler: if len(recents) > timeline_limit: limited = True recents = recents[-timeline_limit:] - room_key = RoomStreamToken.parse(recents[0].internal_metadata.before) + room_key = recents[0].internal_metadata.before prev_batch_token = now_token.copy_and_replace("room_key", room_key) @@ -1595,16 +1595,24 @@ class SyncHandler: if leave_events: leave_event = leave_events[-1] - leave_stream_token = await self.store.get_stream_token_for_event( + leave_position = await self.store.get_position_for_event( leave_event.event_id ) - leave_token = since_token.copy_and_replace( - "room_key", leave_stream_token - ) - if since_token and since_token.is_after(leave_token): + # If the leave event happened before the since token then we + # bail. + if since_token and not leave_position.persisted_after( + since_token.room_key + ): continue + # We can safely convert the position of the leave event into a + # stream token as it'll only be used in the context of this + # room. (c.f. the docstring of `to_room_stream_token`). + leave_token = since_token.copy_and_replace( + "room_key", leave_position.to_room_stream_token() + ) + # If this is an out of band message, like a remote invite # rejection, we include it in the recents batch. Otherwise, we # let _load_filtered_recents handle fetching the correct diff --git a/synapse/notifier.py b/synapse/notifier.py index 441b3d15e2..59415f6f88 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -163,7 +163,7 @@ class _NotifierUserStream: """ # Immediately wake up stream if something has already since happened # since their last token. - if self.last_notified_token.is_after(token): + if self.last_notified_token != token: return _NotificationListener(defer.succeed(self.current_token)) else: return _NotificationListener(self.notify_deferred.observe()) @@ -470,7 +470,7 @@ class Notifier: async def check_for_updates( before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: - if not after_token.is_after(before_token): + if after_token == before_token: return EventStreamResult([], (from_token, from_token)) events = [] # type: List[EventBase] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 55af3d41ea..e165429cad 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import PersistedEventPosition, RoomStreamToken, UserID +from synapse.types import PersistedEventPosition, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -152,9 +152,7 @@ class ReplicationDataHandler: if event.type == EventTypes.Member: extra_users = (UserID.from_string(event.state_key),) - max_token = RoomStreamToken( - None, self.store.get_room_max_stream_ordering() - ) + max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) self.notifier.on_new_room_event( event, event_pos, max_token, extra_users diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 5c5f00b213..ba53f66f02 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -109,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = await self.store.get_topological_token_for_event(event_id) + room_token = await self.store.get_topological_token_for_event(event_id) + token = str(room_token) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 92e96468b4..37249f1e3f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -35,7 +35,6 @@ what sort order was used: - topological tokems: "t%d-%d", where the integers map to the topological and stream ordering columns respectively. """ - import abc import logging from collections import namedtuple @@ -54,7 +53,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine -from synapse.types import Collection, RoomStreamToken +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -305,6 +304,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): def get_room_min_stream_ordering(self) -> int: raise NotImplementedError() + def get_room_max_token(self) -> RoomStreamToken: + return RoomStreamToken(None, self.get_room_max_stream_ordering()) + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], @@ -611,26 +613,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): allow_none=allow_none, ) - async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken: - """The stream token for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream token. + async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: + """Get the persisted position for an event """ - stream_id = await self.get_stream_id_for_event(event_id) - return RoomStreamToken(None, stream_id) + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"event_id": event_id}, + retcols=("stream_ordering", "instance_name"), + desc="get_position_for_event", + ) + + return PersistedEventPosition( + row["instance_name"] or "master", row["stream_ordering"] + ) - async def get_topological_token_for_event(self, event_id: str) -> str: + async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken: """The stream token for an event Args: event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A "t%d-%d" topological token. + A `RoomStreamToken` topological token. """ row = await self.db_pool.simple_select_one( table="events", @@ -638,7 +642,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", ) - return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) + return RoomStreamToken(row["topological_ordering"], row["stream_ordering"]) async def get_current_topological_token(self, room_id: str, stream_key: int) -> int: """Gets the topological token in a room after or at the given stream @@ -687,8 +691,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: topo = None internal = event.internal_metadata - internal.before = str(RoomStreamToken(topo, stream - 1)) - internal.after = str(RoomStreamToken(topo, stream)) + internal.before = RoomStreamToken(topo, stream - 1) + internal.after = RoomStreamToken(topo, stream) internal.order = (int(topo) if topo else 0, int(stream)) async def get_events_around( diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index ded6cf9655..72939f3984 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -229,7 +229,7 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return RoomStreamToken(None, self.main_store.get_current_events_token()) + return self.main_store.get_room_max_token() async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False @@ -247,11 +247,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) - max_persisted_id = self.main_store.get_current_events_token() event_stream_id = event.internal_metadata.stream_ordering pos = PersistedEventPosition(self._instance_name, event_stream_id) - return pos, RoomStreamToken(None, max_persisted_id) + return pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/types.py b/synapse/types.py index ec39f9e1e8..02bcc197ec 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -413,6 +413,18 @@ class RoomStreamToken: pass raise SynapseError(400, "Invalid token %r" % (string,)) + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + return RoomStreamToken(None, max_stream) + def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) @@ -458,31 +470,20 @@ class StreamToken: def room_stream_id(self): return self.room_key.stream - def is_after(self, other): - """Does this token contain events that the other doesn't?""" - return ( - (other.room_stream_id < self.room_stream_id) - or (int(other.presence_key) < int(self.presence_key)) - or (int(other.typing_key) < int(self.typing_key)) - or (int(other.receipt_key) < int(self.receipt_key)) - or (int(other.account_data_key) < int(self.account_data_key)) - or (int(other.push_rules_key) < int(self.push_rules_key)) - or (int(other.to_device_key) < int(self.to_device_key)) - or (int(other.device_list_key) < int(self.device_list_key)) - or (int(other.groups_key) < int(self.groups_key)) - ) - def copy_and_advance(self, key, new_value) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. """ - new_token = self.copy_and_replace(key, new_value) if key == "room_key": - new_id = new_token.room_stream_id - old_id = self.room_stream_id - else: - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) + new_token = self.copy_and_replace( + "room_key", self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + if old_id < new_id: return new_token else: @@ -509,6 +510,18 @@ class PersistedEventPosition: def persisted_after(self, token: RoomStreamToken) -> bool: return token.stream < self.stream + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarentees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0a567b032f..a3287011e9 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -902,15 +902,15 @@ class RoomMessageListTestCase(RoomBase): # Send a first message in the room, which will be removed by the purge. first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] - first_token = self.get_success( - store.get_topological_token_for_event(first_event_id) + first_token = str( + self.get_success(store.get_topological_token_for_event(first_event_id)) ) # Send a second message in the room, which won't be removed, and which we'll # use as the marker to purge events before. second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] - second_token = self.get_success( - store.get_topological_token_for_event(second_event_id) + second_token = str( + self.get_success(store.get_topological_token_for_event(second_event_id)) ) # Send a third event in the room to ensure we don't fall under any edge case diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 918387733b..723cd28933 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -47,8 +47,8 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = self.get_success( - store.get_topological_token_for_event(last["event_id"]) + event = str( + self.get_success(store.get_topological_token_for_event(last["event_id"])) ) # Purge everything before this topological token @@ -74,12 +74,10 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = self.get_success( + token = self.get_success( storage.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format( - *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) - ) + event = "t{}-{}".format(token.topological + 1, token.stream + 1) # Purge everything before this topological token purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) -- cgit 1.5.1 From 6d2d42f8fb04599713d3e6e7fc3bc4c9b7063c9a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 29 Sep 2020 22:26:28 +0100 Subject: Rewrite BucketCollector This was a bit unweildy for what I wanted: in particular, I wanted to assign each measurement straight into a bucket, rather than storing an intermediate Counter which didn't do any bucketing at all. I've replaced it with something that is hopefully a bit easier to use. (I'm not entirely sure what the difference between a HistogramMetricFamily and a GaugeHistogramMetricFamily is, but given our counters can go down as well as up the latter *sounds* more accurate?) --- synapse/metrics/__init__.py | 115 ++++++++++++++++++------------ synapse/storage/databases/main/metrics.py | 26 +++---- tests/storage/test_event_metrics.py | 19 ++--- 3 files changed, 89 insertions(+), 71 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index a1f7ca3449..b8d2a8e8a9 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -15,6 +15,7 @@ import functools import gc +import itertools import logging import os import platform @@ -27,8 +28,8 @@ from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import ( REGISTRY, CounterMetricFamily, + GaugeHistogramMetricFamily, GaugeMetricFamily, - HistogramMetricFamily, ) from twisted.internet import reactor @@ -46,7 +47,7 @@ logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" running_on_pypy = platform.python_implementation() == "PyPy" -all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]] +all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -205,63 +206,83 @@ class InFlightGauge: all_gauges[self.name] = self -@attr.s(slots=True, hash=True) -class BucketCollector: - """ - Like a Histogram, but allows buckets to be point-in-time instead of - incrementally added to. +class GaugeBucketCollector: + """Like a Histogram, but the buckets are Gauges which are updated atomically. - Args: - name (str): Base name of metric to be exported to Prometheus. - data_collector (callable -> dict): A synchronous callable that - returns a dict mapping bucket to number of items in the - bucket. If these buckets are not the same as the buckets - given to this class, they will be remapped into them. - buckets (list[float]): List of floats/ints of the buckets to - give to Prometheus. +Inf is ignored, if given. + The data is updated by calling `update_data` with an iterable of measurements. + We assume that the data is updated less frequently than it is reported to + Prometheus, and optimise for that case. """ - name = attr.ib() - data_collector = attr.ib() - buckets = attr.ib() + __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric") - def collect(self): + def __init__( + self, + name: str, + documentation: str, + buckets: Iterable[float], + registry=REGISTRY, + ): + """ + Args: + name: base name of metric to be exported to Prometheus. (a _bucket suffix + will be added.) + documentation: help text for the metric + buckets: The top bounds of the buckets to report + registry: metric registry to register with + """ + self._name = name + self._documentation = documentation - # Fetch the data -- this must be synchronous! - data = self.data_collector() + # the tops of the buckets + self._bucket_bounds = [float(b) for b in buckets] + if self._bucket_bounds != sorted(self._bucket_bounds): + raise ValueError("Buckets not in sorted order") - buckets = {} # type: Dict[float, int] + if self._bucket_bounds[-1] != float("inf"): + self._bucket_bounds.append(float("inf")) - res = [] - for x in data.keys(): - for i, bound in enumerate(self.buckets): - if x <= bound: - buckets[bound] = buckets.get(bound, 0) + data[x] + self._metric = self._values_to_metric([]) + registry.register(self) - for i in self.buckets: - res.append([str(i), buckets.get(i, 0)]) + def collect(self): + yield self._metric - res.append(["+Inf", sum(data.values())]) + def update_data(self, values: Iterable[float]): + """Update the data to be reported by the metric - metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) + The existing data is cleared, and each measurement in the input is assigned + to the relevant bucket. + """ + self._metric = self._values_to_metric(values) + + def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFamily: + total = 0.0 + bucket_values = [0 for _ in self._bucket_bounds] + + for v in values: + # assign each value to a bucket + for i, bound in enumerate(self._bucket_bounds): + if v <= bound: + bucket_values[i] += 1 + break + + # ... and increment the sum + total += v + + # now, aggregate the bucket values so that they count the number of entries in + # that bucket or below. + accumulated_values = itertools.accumulate(bucket_values) + + return GaugeHistogramMetricFamily( + self._name, + self._documentation, + buckets=list( + zip((str(b) for b in self._bucket_bounds), accumulated_values) + ), + gsum_value=total, ) - yield metric - - def __attrs_post_init__(self): - self.buckets = [float(x) for x in self.buckets if x != "+Inf"] - if self.buckets != sorted(self.buckets): - raise ValueError("Buckets not sorted") - - self.buckets = tuple(self.buckets) - - if self.name in all_gauges.keys(): - logger.warning("%s already registered, reregistering" % (self.name,)) - REGISTRY.unregister(all_gauges.pop(self.name)) - - REGISTRY.register(self) - all_gauges[self.name] = self # diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 686052bd83..4efc093b9e 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,10 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import typing -from collections import Counter -from synapse.metrics import BucketCollector +from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -23,6 +21,14 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +# Collect metrics on the number of forward extremities that exist. +_extremities_collecter = GaugeBucketCollector( + "synapse_forward_extremities", + "Number of rooms on the server with the given number of forward extremities" + " or fewer", + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], +) + class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """Functions to pull various metrics from the DB, for e.g. phone home @@ -32,18 +38,6 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = ( - Counter() - ) # type: typing.Counter[int] - - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) - # Read the extrems every 60 minutes def read_forward_extremities(): # run as a background process to make sure that the database transactions @@ -65,7 +59,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return txn.fetchall() res = await self.db_pool.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = Counter([x[0] for x in res]) + _extremities_collecter.update_data(x[0] for x in res) async def count_daily_messages(self): """ diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 949846fe33..3957471f3f 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase): self.reactor.advance(60 * 60 * 1000) self.pump(1) - items = set( + items = list( filter( lambda x: b"synapse_forward_extremities_" in x, - generate_latest(REGISTRY).split(b"\n"), + generate_latest(REGISTRY, emit_help=False).split(b"\n"), ) ) - expected = { + expected = [ b'synapse_forward_extremities_bucket{le="1.0"} 0.0', b'synapse_forward_extremities_bucket{le="2.0"} 2.0', b'synapse_forward_extremities_bucket{le="3.0"} 2.0', @@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase): b'synapse_forward_extremities_bucket{le="100.0"} 3.0', b'synapse_forward_extremities_bucket{le="200.0"} 3.0', b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - } - + # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9, + # "inf" is valid: "this includes variants such as inf" + b'synapse_forward_extremities_bucket{le="inf"} 3.0', + b"# TYPE synapse_forward_extremities_gcount gauge", + b"synapse_forward_extremities_gcount 3.0", + b"# TYPE synapse_forward_extremities_gsum gauge", + b"synapse_forward_extremities_gsum 10.0", + ] self.assertEqual(items, expected) -- cgit 1.5.1 From 20e7c4de262746479000ec507b7a3c37f1779a60 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 29 Sep 2020 22:30:00 +0100 Subject: Add an improved "forward extremities" metric Hopefully, N(extremities) * N(state_events) is a more realistic approximation to "how big a problem is this room?". --- synapse/storage/databases/main/metrics.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 4efc093b9e..92099f95ce 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -29,6 +29,18 @@ _extremities_collecter = GaugeBucketCollector( buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], ) +# we also expose metrics on the "number of excess extremity events", which is +# (E-1)*N, where E is the number of extremities and N is the number of state +# events in the room. This is an approximation to the number of state events +# we could remove from state resolution by reducing the graph to a single +# forward extremity. +_excess_state_events_collecter = GaugeBucketCollector( + "synapse_excess_extremity_events", + "Number of rooms on the server with the given number of excess extremity " + "events, or fewer", + buckets=[0] + [1 << n for n in range(12)], +) + class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): """Functions to pull various metrics from the DB, for e.g. phone home @@ -52,15 +64,26 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): def fetch(txn): txn.execute( """ - select count(*) c from event_forward_extremities - group by room_id + SELECT t1.c, t2.c + FROM ( + SELECT room_id, COUNT(*) c FROM event_forward_extremities + GROUP BY room_id + ) t1 LEFT JOIN ( + SELECT room_id, COUNT(*) c FROM current_state_events + GROUP BY room_id + ) t2 ON t1.room_id = t2.room_id """ ) return txn.fetchall() res = await self.db_pool.runInteraction("read_forward_extremities", fetch) + _extremities_collecter.update_data(x[0] for x in res) + _excess_state_events_collecter.update_data( + (x[0] - 1) * x[1] for x in res if x[1] + ) + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. -- cgit 1.5.1 From 7941372ec84786f85ae6d75fd2d7a4af5b72ac98 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Sep 2020 20:29:19 +0100 Subject: Make token serializing/deserializing async (#8427) The idea is that in future tokens will encode a mapping of instance to position. However, we don't want to include the full instance name in the string representation, so instead we'll have a mapping between instance name and an immutable integer ID in the DB that we can use instead. We'll then do the lookup when we serialize/deserialize the token (we could alternatively pass around an `Instance` type that includes both the name and ID, but that turns out to be a lot more invasive). --- changelog.d/8427.misc | 1 + synapse/handlers/events.py | 4 +-- synapse/handlers/initial_sync.py | 14 ++++----- synapse/handlers/pagination.py | 8 ++--- synapse/handlers/room.py | 8 +++-- synapse/handlers/search.py | 8 ++--- synapse/rest/admin/__init__.py | 2 +- synapse/rest/client/v1/events.py | 3 +- synapse/rest/client/v1/initial_sync.py | 3 +- synapse/rest/client/v1/room.py | 11 +++++-- synapse/rest/client/v2_alpha/keys.py | 3 +- synapse/rest/client/v2_alpha/sync.py | 10 +++--- synapse/storage/databases/main/purge_events.py | 8 ++--- synapse/streams/config.py | 9 +++--- synapse/types.py | 43 +++++++++++++++++++++----- tests/rest/client/v1/test_rooms.py | 30 +++++++++++++----- tests/storage/test_purge.py | 9 ++++-- 17 files changed, 115 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8427.misc (limited to 'synapse/storage') diff --git a/changelog.d/8427.misc b/changelog.d/8427.misc new file mode 100644 index 0000000000..c9656b9112 --- /dev/null +++ b/changelog.d/8427.misc @@ -0,0 +1 @@ +Make stream token serializing/deserializing async. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 0875b74ea8..539b4fc32e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler): chunk = { "chunk": chunks, - "start": tokens[0].to_string(), - "end": tokens[1].to_string(), + "start": await tokens[0].to_string(self.store), + "end": await tokens[1].to_string(self.store), } return chunk diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 43f15435de..39a85801c1 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler): messages, time_now=time_now, as_client_event=as_client_event ) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), } d["state"] = await self._event_serializer.serialize_events( @@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler): ], "account_data": account_data_events, "receipts": receipt, - "end": now_token.to_string(), + "end": await now_token.to_string(self.store), } return ret @@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": ( await self._event_serializer.serialize_events( @@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler): "chunk": ( await self._event_serializer.serialize_events(messages, time_now) ), - "start": start_token.to_string(), - "end": end_token.to_string(), + "start": await start_token.to_string(self.store), + "end": await end_token.to_string(self.store), }, "state": state, "presence": presence, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index d6779a4b44..2c2a633938 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -413,8 +413,8 @@ class PaginationHandler: if not events: return { "chunk": [], - "start": from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } state = None @@ -442,8 +442,8 @@ class PaginationHandler: events, time_now, as_client_event=as_client_event ) ), - "start": from_token.to_string(), - "end": next_token.to_string(), + "start": await from_token.to_string(self.store), + "end": await next_token.to_string(self.store), } if state: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 836b3f381a..d5f7c78edf 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1077,11 +1077,13 @@ class RoomContextHandler: # the token, which we replace. token = StreamToken.START - results["start"] = token.copy_and_replace( + results["start"] = await token.copy_and_replace( "room_key", results["start"] - ).to_string() + ).to_string(self.store) - results["end"] = token.copy_and_replace("room_key", results["end"]).to_string() + results["end"] = await token.copy_and_replace( + "room_key", results["end"] + ).to_string(self.store) return results diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6a76c20d79..e9402e6e2e 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -362,13 +362,13 @@ class SearchHandler(BaseHandler): self.storage, user.to_string(), res["events_after"] ) - res["start"] = now_token.copy_and_replace( + res["start"] = await now_token.copy_and_replace( "room_key", res["start"] - ).to_string() + ).to_string(self.store) - res["end"] = now_token.copy_and_replace( + res["end"] = await now_token.copy_and_replace( "room_key", res["end"] - ).to_string() + ).to_string(self.store) if include_profile: senders = { diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ba53f66f02..57cac22252 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -110,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet): raise SynapseError(400, "Event is for wrong room.") room_token = await self.store.get_topological_token_for_event(event_id) - token = str(room_token) + token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 985d994f6b..1ecb77aa26 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet): if b"room_id" in request.args: room_id = request.args[b"room_id"][0].decode("ascii") - pagin_config = PaginationConfig.from_request(request) + pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in request.args: try: diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index d7042786ce..91da0ee573 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 7e64a2e0fe..b63389e5fe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -451,6 +451,7 @@ class RoomMemberListRestServlet(RestServlet): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) @@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet): if at_token_string is None: at_token = None else: - at_token = StreamToken.from_string(at_token_string) + at_token = await StreamToken.from_string(self.store, at_token_string) # let you filter down on particular memberships. # XXX: this may not be the best shape for this API - we could pass in a filter @@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request, default_limit=10) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) as_client_event = b"raw" not in request.args filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: @@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 7abd6ff333..55c4606569 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet): # changes after the "to" as well as before. set_tag("to", parse_string(request, "to")) - from_token = StreamToken.from_string(from_token_string) + from_token = await StreamToken.from_string(self.store, from_token_string) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 51e395cc64..6779df952f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self.store = hs.get_datastore() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() @@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet): device_id=device_id, ) + since_token = None if since is not None: - since_token = StreamToken.from_string(since) - else: - since_token = None + since_token = await StreamToken.from_string(self.store, since) # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) @@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), + "next_batch": await sync_result.next_batch.to_string(self.store), } @staticmethod @@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet): result = { "timeline": { "events": serialized_timeline, - "prev_batch": room.timeline.prev_batch.to_string(), + "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, "state": {"events": serialized_state}, diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d7a03cbf7d..ecfc6717b3 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): The set of state groups that are referenced by deleted events. """ + parsed_token = await RoomStreamToken.parse(self, token) + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, - token, + parsed_token, delete_local_events, ) - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - + def _purge_history_txn(self, txn, room_id, token, delete_local_events): # Tables that should be pruned: # event_auth # event_backward_extremities diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 0bdf846edf..fdda21d165 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Optional @@ -21,6 +20,7 @@ import attr from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string from synapse.http.site import SynapseRequest +from synapse.storage.databases.main import DataStore from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -39,8 +39,9 @@ class PaginationConfig: limit = attr.ib(type=Optional[int]) @classmethod - def from_request( + async def from_request( cls, + store: "DataStore", request: SynapseRequest, raise_invalid_params: bool = True, default_limit: Optional[int] = None, @@ -54,13 +55,13 @@ class PaginationConfig: if from_tok == "END": from_tok = None # For backwards compat. elif from_tok: - from_tok = StreamToken.from_string(from_tok) + from_tok = await StreamToken.from_string(store, from_tok) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: - to_tok = StreamToken.from_string(to_tok) + to_tok = await StreamToken.from_string(store, to_tok) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/synapse/types.py b/synapse/types.py index 02bcc197ec..bd271f9f16 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,17 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + MutableMapping, + Optional, + Tuple, + Type, + TypeVar, +) import attr from signedjson.key import decode_verify_key_bytes @@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): from typing import Collection @@ -393,7 +406,7 @@ class RoomStreamToken: stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) @classmethod - def parse(cls, string: str) -> "RoomStreamToken": + async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -428,7 +441,7 @@ class RoomStreamToken: def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) - def __str__(self) -> str: + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) else: @@ -453,18 +466,32 @@ class StreamToken: START = None # type: StreamToken @classmethod - def from_string(cls, string): + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": try: keys = string.split(cls._SEPARATOR) while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") - return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:])) + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) except Exception: raise SynapseError(400, "Invalid Token") - def to_string(self): - return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)]) + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + str(self.groups_key), + ] + ) @property def room_stream_id(self): @@ -493,7 +520,7 @@ class StreamToken: return attr.evolve(self, **{key: new_value}) -StreamToken.START = StreamToken.from_string("s0_0") +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) @attr.s(slots=True, frozen=True) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a3287011e9..0d809d25d5 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -902,16 +902,18 @@ class RoomMessageListTestCase(RoomBase): # Send a first message in the room, which will be removed by the purge. first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] - first_token = str( - self.get_success(store.get_topological_token_for_event(first_event_id)) + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) ) + first_token_str = self.get_success(first_token.to_string(store)) # Send a second message in the room, which won't be removed, and which we'll # use as the marker to purge events before. second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] - second_token = str( - self.get_success(store.get_topological_token_for_event(second_event_id)) + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) ) + second_token_str = self.get_success(second_token.to_string(store)) # Send a third event in the room to ensure we don't fall under any edge case # due to our marker being the latest forward extremity in the room. @@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) @@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase): pagination_handler._purge_history( purge_id=purge_id, room_id=self.room_id, - token=second_token, + token=second_token_str, delete_local_events=True, ) ) @@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) @@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase): request, channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), ) self.render(request) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 723cd28933..cc1f3c53c5 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = str( - self.get_success(store.get_topological_token_for_event(last["event_id"])) + token = self.get_success( + store.get_topological_token_for_event(last["event_id"]) ) + token_str = self.get_success(token.to_string(self.hs.get_datastore())) # Purge everything before this topological token - self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) + self.get_success( + storage.purge_events.purge_history(self.room_id, token_str, True) + ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. -- cgit 1.5.1 From 4ff0201e6235b8b2efc5ce5a7dc3c479ea96df53 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 1 Oct 2020 08:09:18 -0400 Subject: Enable mypy checking for unreachable code and fix instances. (#8432) --- changelog.d/8432.misc | 1 + mypy.ini | 1 + synapse/config/tls.py | 18 +++++++++--------- synapse/federation/federation_server.py | 5 ++--- synapse/handlers/directory.py | 2 +- synapse/handlers/room.py | 2 -- synapse/handlers/room_member.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/http/server.py | 4 ++-- synapse/logging/_structured.py | 10 +--------- synapse/push/push_rule_evaluator.py | 4 ++-- synapse/replication/tcp/protocol.py | 10 ++++++---- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/censor_events.py | 6 +++--- synapse/storage/databases/main/events.py | 18 +++++------------- synapse/storage/databases/main/stream.py | 2 +- synapse/storage/util/id_generators.py | 2 +- 17 files changed, 38 insertions(+), 53 deletions(-) create mode 100644 changelog.d/8432.misc (limited to 'synapse/storage') diff --git a/changelog.d/8432.misc b/changelog.d/8432.misc new file mode 100644 index 0000000000..01fdad4caf --- /dev/null +++ b/changelog.d/8432.misc @@ -0,0 +1 @@ +Check for unreachable code with mypy. diff --git a/mypy.ini b/mypy.ini index 7986781432..c283f15b21 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,7 @@ check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs +warn_unreachable = True files = synapse/api, synapse/appservice, diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 9ddb8b546b..ad37b93c02 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -18,7 +18,7 @@ import os import warnings from datetime import datetime from hashlib import sha256 -from typing import List +from typing import List, Optional from unpaddedbase64 import encode_base64 @@ -177,8 +177,8 @@ class TlsConfig(Config): "use_insecure_ssl_client_just_for_testing_do_not_use" ) - self.tls_certificate = None - self.tls_private_key = None + self.tls_certificate = None # type: Optional[crypto.X509] + self.tls_private_key = None # type: Optional[crypto.PKey] def is_disk_cert_valid(self, allow_self_signed=True): """ @@ -226,12 +226,12 @@ class TlsConfig(Config): days_remaining = (expires_on - now).days return days_remaining - def read_certificate_from_disk(self, require_cert_and_key): + def read_certificate_from_disk(self, require_cert_and_key: bool): """ Read the certificates and private key from disk. Args: - require_cert_and_key (bool): set to True to throw an error if the certificate + require_cert_and_key: set to True to throw an error if the certificate and key file are not given """ if require_cert_and_key: @@ -479,13 +479,13 @@ class TlsConfig(Config): } ) - def read_tls_certificate(self): + def read_tls_certificate(self) -> crypto.X509: """Reads the TLS certificate from the configured file, and returns it Also checks if it is self-signed, and warns if so Returns: - OpenSSL.crypto.X509: the certificate + The certificate """ cert_path = self.tls_certificate_file logger.info("Loading TLS certificate from %s", cert_path) @@ -504,11 +504,11 @@ class TlsConfig(Config): return cert - def read_tls_private_key(self): + def read_tls_private_key(self) -> crypto.PKey: """Reads the TLS private key from the configured file, and returns it Returns: - OpenSSL.crypto.PKey: the private key + The private key """ private_key_path = self.tls_private_key_file logger.info("Loading TLS key from %s", private_key_path) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 24329dd0e3..02f11e1209 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -22,7 +22,6 @@ from typing import ( Callable, Dict, List, - Match, Optional, Tuple, Union, @@ -825,14 +824,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: return False -def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: +def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool: if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False regex = glob_to_regex(acl_entry) - return regex.match(server_name) + return bool(regex.match(server_name)) class FederationHandlerRegistry: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 62aa9a2da8..6f15c68240 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -383,7 +383,7 @@ class DirectoryHandler(BaseHandler): """ creator = await self.store.get_room_alias_creator(alias.to_string()) - if creator is not None and creator == user_id: + if creator == user_id: return True # Resolve the alias to the corresponding room. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d5f7c78edf..f1a6699cd4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -962,8 +962,6 @@ class RoomCreationHandler(BaseHandler): try: random_string = stringutils.random_string(18) gen_room_id = RoomID(random_string, self.hs.hostname).to_string() - if isinstance(gen_room_id, bytes): - gen_room_id = gen_room_id.decode("utf-8") await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8feba8c90a..5ec36f591d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -642,7 +642,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def send_membership_event( self, - requester: Requester, + requester: Optional[Requester], event: EventBase, context: EventContext, ratelimit: bool = True, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bfe2583002..260ec19b41 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -87,7 +87,7 @@ class SyncConfig: class TimelineBatch: prev_batch = attr.ib(type=StreamToken) events = attr.ib(type=List[EventBase]) - limited = attr.ib(bool) + limited = attr.ib(type=bool) def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used diff --git a/synapse/http/server.py b/synapse/http/server.py index 996a31a9ec..09ed74f6ce 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return @@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 144506c8f2..0fc2ea609e 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import os.path import sys @@ -89,14 +88,7 @@ class LogContextObserver: context = current_context() # Copy the context information to the log event. - if context is not None: - context.copy_to_twisted_log_entry(event) - else: - # If there's no logging context, not even the root one, we might be - # starting up or it might be from non-Synapse code. Log it as if it - # came from the root logger. - event["request"] = None - event["scope"] = None + context.copy_to_twisted_log_entry(event) self.observer(event) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 709ace01e5..3a68ce636f 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,7 +16,7 @@ import logging import re -from typing import Any, Dict, List, Pattern, Union +from typing import Any, Dict, List, Optional, Pattern, Union from synapse.events import EventBase from synapse.types import UserID @@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent: return r.search(body) - def _get_value(self, dotted_key: str) -> str: + def _get_value(self, dotted_key: str) -> Optional[str]: return self._value_cache.get(dotted_key, None) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0b0d204e64..a509e599c2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -51,10 +51,11 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from twisted.internet import task from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 - self.time_we_closed = None # When we requested the connection be closed + # When we requested the connection be closed + self.time_we_closed = None # type: Optional[int] - self.received_ping = False # Have we reecived a ping from the other side + self.received_ping = False # Have we received a ping from the other side self.state = ConnectionStates.CONNECTING @@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. - self._send_ping_loop = None + self._send_ping_loop = None # type: Optional[task.LoopingCall] # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 31082bb16a..5b0900aa3c 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -738,7 +738,7 @@ def _make_state_cache_entry( # failing that, look for the closest match. prev_group = None - delta_ids = None + delta_ids = None # type: Optional[StateMap[str]] for old_group, old_state in state_groups_ids.items(): n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index f211ddbaf8..4bb2b9c28c 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.storage.databases.main.events import encode_json from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.frozenutils import frozendict_json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( + pruned_json = frozendict_json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 18def01f50..78e645592f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -52,16 +52,6 @@ event_counter = Counter( ) -def encode_json(json_object): - """ - Encode a Python object as JSON and return it in a Unicode string. - """ - out = frozendict_json_encoder.encode(json_object) - if isinstance(out, bytes): - out = out.decode("utf8") - return out - - _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @@ -743,7 +733,9 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = encode_json(event.internal_metadata.get_dict()) + metadata_json = frozendict_json_encoder.encode( + event.internal_metadata.get_dict() + ) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -797,10 +789,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": encode_json( + "internal_metadata": frozendict_json_encoder.encode( event.internal_metadata.get_dict() ), - "json": encode_json(event_dict(event)), + "json": frozendict_json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 37249f1e3f..1d27439536 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -546,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_room_event_before_stream_ordering( self, room_id: str, stream_ordering: int - ) -> Tuple[int, int, str]: + ) -> Optional[Tuple[int, int, str]]: """Gets details of the first event in a room at or before a stream ordering Args: diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 02fbb656e8..ec356b2e4f 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -421,7 +421,7 @@ class MultiWriterIdGenerator: self._unfinished_ids.discard(next_id) self._finished_ids.add(next_id) - new_cur = None + new_cur = None # type: Optional[int] if self._unfinished_ids: # If there are unfinished IDs then the new position will be the -- cgit 1.5.1 From 695240d34a9dd1c34379ded1fbbbe42a1850549e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Oct 2020 12:22:19 +0100 Subject: Fix DB query on startup for negative streams. (#8447) For negative streams we have to negate the internal stream ID before querying the DB. The effect of this bug was to query far too many rows, slowing start up time, but we would correctly filter the results afterwards so there was no ill effect. --- changelog.d/8447.bugfix | 1 + synapse/storage/util/id_generators.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8447.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/8447.bugfix b/changelog.d/8447.bugfix new file mode 100644 index 0000000000..88edaf322e --- /dev/null +++ b/changelog.d/8447.bugfix @@ -0,0 +1 @@ +Fix DB query on startup for negative streams which caused long start up times. Introduced in #8374. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 02fbb656e8..48efbb5067 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -341,7 +341,7 @@ class MultiWriterIdGenerator: "cmp": "<=" if self._positive else ">=", } sql = self._db.engine.convert_param_style(sql) - cur.execute(sql, (min_stream_id,)) + cur.execute(sql, (min_stream_id * self._return_factor,)) self._persisted_upto_position = min_stream_id -- cgit 1.5.1 From 62894673e69f7beb0d0a748ad01c2e95c5fed106 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 2 Oct 2020 08:23:15 -0400 Subject: Allow background tasks to be run on a separate worker. (#8369) --- changelog.d/8369.feature | 1 + docs/sample_config.yaml | 5 + docs/workers.md | 17 ++ synapse/app/_base.py | 6 + synapse/app/admin_cmd.py | 1 + synapse/app/generic_worker.py | 4 + synapse/app/homeserver.py | 182 ------------------- synapse/app/phone_stats_home.py | 202 +++++++++++++++++++++ synapse/config/workers.py | 18 ++ synapse/handlers/auth.py | 2 +- synapse/handlers/stats.py | 2 +- synapse/server.py | 17 +- synapse/storage/databases/main/__init__.py | 191 ------------------- synapse/storage/databases/main/metrics.py | 195 ++++++++++++++++++++ .../storage/databases/main/monthly_active_users.py | 109 +++++------ synapse/storage/databases/main/room.py | 24 +-- synapse/storage/databases/main/ui_auth.py | 6 +- tests/test_phone_home.py | 2 +- tests/utils.py | 2 +- 19 files changed, 537 insertions(+), 449 deletions(-) create mode 100644 changelog.d/8369.feature create mode 100644 synapse/app/phone_stats_home.py (limited to 'synapse/storage') diff --git a/changelog.d/8369.feature b/changelog.d/8369.feature new file mode 100644 index 0000000000..542993110b --- /dev/null +++ b/changelog.d/8369.feature @@ -0,0 +1 @@ +Allow running background tasks in a separate worker process. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b2c1d7a737..7126ade2de 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2504,6 +2504,11 @@ opentracing: # events: worker1 # typing: worker1 +# The worker that is used to run background tasks (e.g. cleaning up expired +# data). If not provided this defaults to the main process. +# +#run_background_tasks_on: worker1 + # Configuration for Redis when using workers. This *must* be enabled when # using workers (unless using old style direct TCP configuration). diff --git a/docs/workers.md b/docs/workers.md index ad4d8ca9f2..84a9759e34 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -319,6 +319,23 @@ stream_writers: events: event_persister1 ``` +#### Background tasks + +There is also *experimental* support for moving background tasks to a separate +worker. Background tasks are run periodically or started via replication. Exactly +which tasks are configured to run depends on your Synapse configuration (e.g. if +stats is enabled). + +To enable this, the worker must have a `worker_name` and can be configured to run +background tasks. For example, to move background tasks to a dedicated worker, +the shared configuration would include: + +```yaml +run_background_tasks_on: background_worker +``` + +You might also wish to investigate the `update_user_directory` and +`media_instance_running_background_jobs` settings. ### `synapse.app.pusher` diff --git a/synapse/app/_base.py b/synapse/app/_base.py index fb476ddaf5..8bb0b142ca 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -28,6 +28,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory import synapse from synapse.app import check_bind_error +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.server import ListenerConfig from synapse.crypto import context_factory from synapse.logging.context import PreserveLoggingContext @@ -274,6 +275,11 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): setup_sentry(hs) setup_sdnotify(hs) + # If background tasks are running on the main process, start collecting the + # phone home stats. + if hs.config.run_background_tasks: + start_phone_stats_home(hs) + # We now freeze all allocated objects in the hopes that (almost) # everything currently allocated are things that will be used for the # rest of time. Doing so means less work each GC (hopefully). diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 7d309b1bb0..f0d65d08d7 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -208,6 +208,7 @@ def start(config_options): # Explicitly disable background processes config.update_user_directory = False + config.run_background_tasks = False config.start_pushers = False config.send_federation = False diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index c38413c893..fc5188ce95 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -128,11 +128,13 @@ from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore +from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.databases.main.presence import UserPresenceState from synapse.storage.databases.main.search import SearchWorkerStore +from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -454,6 +456,7 @@ class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. UserDirectoryStore, + StatsStore, UIAuthWorkerStore, SlavedDeviceInboxStore, SlavedDeviceStore, @@ -476,6 +479,7 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + ServerMetricsStore, SearchWorkerStore, BaseSlavedStore, ): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index dff739e106..4ed4a2c253 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -17,14 +17,10 @@ import gc import logging -import math import os -import resource import sys from typing import Iterable -from prometheus_client import Gauge - from twisted.application import service from twisted.internet import defer, reactor from twisted.python.failure import Failure @@ -60,7 +56,6 @@ from synapse.http.server import ( from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.python_dependencies import check_requirements from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource @@ -334,20 +329,6 @@ class SynapseHomeServer(HomeServer): logger.warning("Unrecognized listener type: %s", listener.type) -# Gauges to expose monthly active user control metrics -current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") -current_mau_by_service_gauge = Gauge( - "synapse_admin_mau_current_mau_by_service", - "Current MAU by service", - ["app_service"], -) -max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") -registered_reserved_users_mau_gauge = Gauge( - "synapse_admin_mau:registered_reserved_users", - "Registered users with reserved threepids", -) - - def setup(config_options): """ Args: @@ -389,8 +370,6 @@ def setup(config_options): except UpgradeDatabaseException as e: quit_with_error("Failed to upgrade database: %s" % (e,)) - hs.setup_master() - async def do_acme() -> bool: """ Reprovision an ACME certificate, if it's required. @@ -486,92 +465,6 @@ class SynapseService(service.Service): return self._port.stopListening() -# Contains the list of processes we will be monitoring -# currently either 0 or 1 -_stats_process = [] - - -async def phone_stats_home(hs, stats, stats_process=_stats_process): - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - uptime = int(now - hs.start_time) - if uptime < 0: - uptime = 0 - - # - # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. - # - old = stats_process[0] - new = (now, resource.getrusage(resource.RUSAGE_SELF)) - stats_process[0] = new - - # Get RSS in bytes - stats["memory_rss"] = new[1].ru_maxrss - - # Get CPU time in % of a single core, not % of all cores - used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( - old[1].ru_utime + old[1].ru_stime - ) - if used_cpu_time == 0 or new[0] == old[0]: - stats["cpu_average"] = 0 - else: - stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) - - # - # General statistics - # - - stats["homeserver"] = hs.config.server_name - stats["server_context"] = hs.config.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = await hs.get_datastore().count_all_users() - - total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = await hs.get_datastore().count_daily_user_type() - for name, count in daily_user_type_results.items(): - stats["daily_user_type_" + name] = count - - room_count = await hs.get_datastore().get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = await hs.get_datastore().count_daily_users() - stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() - stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() - stats["daily_messages"] = await hs.get_datastore().count_daily_messages() - - r30_results = await hs.get_datastore().count_r30_users() - for name, count in r30_results.items(): - stats["r30_users_" + name] = count - - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = hs.config.caches.global_factor - stats["event_cache_size"] = hs.config.caches.event_cache_size - - # - # Database version - # - - # This only reports info about the *main* database. - stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ - stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version - - logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) - try: - await hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats - ) - except Exception as e: - logger.warning("Error reporting stats: %s", e) - - def run(hs): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: @@ -597,81 +490,6 @@ def run(hs): ThreadPool._worker = profile(ThreadPool._worker) reactor.run = profile(reactor.run) - clock = hs.get_clock() - - stats = {} - - def performance_stats_init(): - _stats_process.clear() - _stats_process.append( - (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) - ) - - def start_phone_stats_home(): - return run_as_background_process( - "phone_stats_home", phone_stats_home, hs, stats - ) - - def generate_user_daily_visit_stats(): - return run_as_background_process( - "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits - ) - - # Rather than update on per session basis, batch up the requests. - # If you increase the loop period, the accuracy of user_daily_visits - # table will decrease - clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) - - # monthly active user limiting functionality - def reap_monthly_active_users(): - return run_as_background_process( - "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users - ) - - clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) - reap_monthly_active_users() - - async def generate_monthly_active_users(): - current_mau_count = 0 - current_mau_count_by_service = {} - reserved_users = () - store = hs.get_datastore() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - current_mau_count = await store.get_monthly_active_count() - current_mau_count_by_service = ( - await store.get_monthly_active_count_by_service() - ) - reserved_users = await store.get_registered_reserved_users() - current_mau_gauge.set(float(current_mau_count)) - - for app_service, count in current_mau_count_by_service.items(): - current_mau_by_service_gauge.labels(app_service).set(float(count)) - - registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.max_mau_value)) - - def start_generate_monthly_active_users(): - return run_as_background_process( - "generate_monthly_active_users", generate_monthly_active_users - ) - - start_generate_monthly_active_users() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) - # End of monthly active user settings - - if hs.config.report_stats: - logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) - - # We need to defer this init for the cases that we daemonize - # otherwise the process ID we get is that of the non-daemon process - clock.call_later(0, performance_stats_init) - - # We wait 5 minutes to send the first set of stats as the server can - # be quite busy the first few minutes - clock.call_later(5 * 60, start_phone_stats_home) - _base.start_reactor( "synapse-homeserver", soft_file_limit=hs.config.soft_file_limit, diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py new file mode 100644 index 0000000000..2c8e14a8c0 --- /dev/null +++ b/synapse/app/phone_stats_home.py @@ -0,0 +1,202 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import math +import resource +import sys + +from prometheus_client import Gauge + +from synapse.metrics.background_process_metrics import run_as_background_process + +logger = logging.getLogger("synapse.app.homeserver") + +# Contains the list of processes we will be monitoring +# currently either 0 or 1 +_stats_process = [] + +# Gauges to expose monthly active user control metrics +current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") +current_mau_by_service_gauge = Gauge( + "synapse_admin_mau_current_mau_by_service", + "Current MAU by service", + ["app_service"], +) +max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit") +registered_reserved_users_mau_gauge = Gauge( + "synapse_admin_mau:registered_reserved_users", + "Registered users with reserved threepids", +) + + +async def phone_stats_home(hs, stats, stats_process=_stats_process): + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 + + # + # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new + + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss + + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime + ) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # General statistics + # + + stats["homeserver"] = hs.config.server_name + stats["server_context"] = hs.config.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = await hs.get_datastore().count_all_users() + + total_nonbridged_users = await hs.get_datastore().count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = await hs.get_datastore().count_daily_user_type() + for name, count in daily_user_type_results.items(): + stats["daily_user_type_" + name] = count + + room_count = await hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = await hs.get_datastore().count_daily_users() + stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() + stats["daily_messages"] = await hs.get_datastore().count_daily_messages() + + r30_results = await hs.get_datastore().count_r30_users() + for name, count in r30_results.items(): + stats["r30_users_" + name] = count + + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size + + # + # Database version + # + + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version + + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + try: + await hs.get_proxied_http_client().put_json( + hs.config.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + +def start_phone_stats_home(hs): + """ + Start the background tasks which report phone home stats. + """ + clock = hs.get_clock() + + stats = {} + + def performance_stats_init(): + _stats_process.clear() + _stats_process.append( + (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) + ) + + def start_phone_stats_home(): + return run_as_background_process( + "phone_stats_home", phone_stats_home, hs, stats + ) + + def generate_user_daily_visit_stats(): + return run_as_background_process( + "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits + ) + + # Rather than update on per session basis, batch up the requests. + # If you increase the loop period, the accuracy of user_daily_visits + # table will decrease + clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + + # monthly active user limiting functionality + def reap_monthly_active_users(): + return run_as_background_process( + "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users + ) + + clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) + reap_monthly_active_users() + + async def generate_monthly_active_users(): + current_mau_count = 0 + current_mau_count_by_service = {} + reserved_users = () + store = hs.get_datastore() + if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + current_mau_count = await store.get_monthly_active_count() + current_mau_count_by_service = ( + await store.get_monthly_active_count_by_service() + ) + reserved_users = await store.get_registered_reserved_users() + current_mau_gauge.set(float(current_mau_count)) + + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels(app_service).set(float(count)) + + registered_reserved_users_mau_gauge.set(float(len(reserved_users))) + max_mau_gauge.set(float(hs.config.max_mau_value)) + + def start_generate_monthly_active_users(): + return run_as_background_process( + "generate_monthly_active_users", generate_monthly_active_users + ) + + if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + start_generate_monthly_active_users() + clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) + # End of monthly active user settings + + if hs.config.report_stats: + logger.info("Scheduling stats reporting for 3 hour intervals") + clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) + + # We need to defer this init for the cases that we daemonize + # otherwise the process ID we get is that of the non-daemon process + clock.call_later(0, performance_stats_init) + + # We wait 5 minutes to send the first set of stats as the server can + # be quite busy the first few minutes + clock.call_later(5 * 60, start_phone_stats_home) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index f23e42cdf9..57ab097eba 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -132,6 +132,19 @@ class WorkerConfig(Config): self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + # Whether this worker should run background tasks or not. + # + # As a note for developers, the background tasks guarded by this should + # be able to run on only a single instance (meaning that they don't + # depend on any in-memory state of a particular worker). + # + # No effort is made to ensure only a single instance of these tasks is + # running. + background_tasks_instance = config.get("run_background_tasks_on") or "master" + self.run_background_tasks = ( + self.worker_name is None and background_tasks_instance == "master" + ) or self.worker_name == background_tasks_instance + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ ## Workers ## @@ -167,6 +180,11 @@ class WorkerConfig(Config): #stream_writers: # events: worker1 # typing: worker1 + + # The worker that is used to run background tasks (e.g. cleaning up expired + # data). If not provided this defaults to the main process. + # + #run_background_tasks_on: worker1 """ def read_arguments(self, args): diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 00eae92052..7c4b716b28 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -212,7 +212,7 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. - if hs.config.worker_app is None: + if hs.config.run_background_tasks: self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 249ffe2a55..dc62b21c06 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -49,7 +49,7 @@ class StatsHandler: # Guard to ensure we only process deltas one at a time self._is_processing = False - if hs.config.stats_enabled: + if self.stats_enabled and hs.config.run_background_tasks: self.notifier.add_replication_callback(self.notify_new_event) # We kick this off so that we don't have to wait for a change before diff --git a/synapse/server.py b/synapse/server.py index 5e3752c333..aa2273955c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -185,7 +185,10 @@ class HomeServer(metaclass=abc.ABCMeta): we are listening on to provide HTTP services. """ - REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] + REQUIRED_ON_BACKGROUND_TASK_STARTUP = [ + "auth", + "stats", + ] # This is overridden in derived application classes # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be @@ -251,14 +254,20 @@ class HomeServer(metaclass=abc.ABCMeta): self.datastores = Databases(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") - def setup_master(self) -> None: + # Register background tasks required by this server. This must be done + # somewhat manually due to the background tasks not being registered + # unless handlers are instantiated. + if self.config.run_background_tasks: + self.setup_background_tasks() + + def setup_background_tasks(self) -> None: """ Some handlers have side effects on instantiation (like registering background updates). This function causes them to be fetched, and therefore instantiated, to run those side effects. """ - for i in self.REQUIRED_ON_MASTER_STARTUP: - getattr(self, "get_" + i)() + for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: + getattr(self, "get_" + i + "_handler")() def get_reactor(self) -> twisted.internet.base.ReactorBase: """ diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0cb12f4c61..f823d66709 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,9 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import calendar import logging -import time from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState @@ -268,9 +266,6 @@ class DataStore( self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - # Used in _generate_user_daily_visits to keep track of progress - self._last_user_visit_update = self._get_start_of_day() - def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() @@ -301,192 +296,6 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - async def count_daily_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 24 hours. - """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return await self.db_pool.runInteraction( - "count_daily_users", self._count_users, yesterday - ) - - async def count_monthly_users(self) -> int: - """ - Counts the number of users who used this homeserver in the last 30 days. - Note this method is intended for phonehome metrics only and is different - from the mau figure in synapse.storage.monthly_active_users which, - amongst other things, includes a 3 day grace period before a user counts. - """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return await self.db_pool.runInteraction( - "count_monthly_users", self._count_users, thirty_days_ago - ) - - def _count_users(self, txn, time_from): - """ - Returns number of users seen in the past time_from period - """ - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT user_id FROM user_ips - WHERE last_seen > ? - GROUP BY user_id - ) u - """ - txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() - return count - - async def count_r30_users(self) -> Dict[str, int]: - """ - Counts the number of 30 day retained users, defined as:- - * Users who have created their accounts more than 30 days ago - * Where last seen at most 30 days ago - * Where account creation and last_seen are > 30 days apart - - Returns: - A mapping of counts globally as well as broken out by platform. - """ - - def _count_r30_users(txn): - thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) - thirty_days_ago_in_secs = now - thirty_days_in_secs - - sql = """ - SELECT platform, COALESCE(count(*), 0) FROM ( - SELECT - users.name, platform, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen, - CASE - WHEN user_agent LIKE '%%Android%%' THEN 'android' - WHEN user_agent LIKE '%%iOS%%' THEN 'ios' - WHEN user_agent LIKE '%%Electron%%' THEN 'electron' - WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' - WHEN user_agent LIKE '%%Gecko%%' THEN 'web' - ELSE 'unknown' - END - AS platform - FROM user_ips - ) uip - ON users.name = uip.user_id - AND users.appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, platform, users.creation_ts - ) u GROUP BY platform - """ - - results = {} - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - for row in txn: - if row[0] == "unknown": - pass - results[row[0]] = row[1] - - sql = """ - SELECT COALESCE(count(*), 0) FROM ( - SELECT users.name, users.creation_ts * 1000, - MAX(uip.last_seen) - FROM users - INNER JOIN ( - SELECT - user_id, - last_seen - FROM user_ips - ) uip - ON users.name = uip.user_id - AND appservice_id is NULL - AND users.creation_ts < ? - AND uip.last_seen/1000 > ? - AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 - GROUP BY users.name, users.creation_ts - ) u - """ - - txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - - (count,) = txn.fetchone() - results["all"] = count - - return results - - return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) - - def _get_start_of_day(self): - """ - Returns millisecond unixtime for start of UTC day. - """ - now = time.gmtime() - today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) - return today_start * 1000 - - async def generate_user_daily_visits(self) -> None: - """ - Generates daily visit data for use in cohort/ retention analysis - """ - - def _generate_user_daily_visits(txn): - logger.info("Calling _generate_user_daily_visits") - today_start = self._get_start_of_day() - a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self.clock.time_msec() - - sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? - FROM user_ips AS u - LEFT JOIN ( - SELECT user_id, device_id, timestamp FROM user_daily_visits - WHERE timestamp = ? - ) udv - ON u.user_id = udv.user_id AND u.device_id=udv.device_id - INNER JOIN users ON users.name=u.user_id - WHERE last_seen > ? AND last_seen <= ? - AND udv.timestamp IS NULL AND users.is_guest=0 - AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id - """ - - # This means that the day has rolled over but there could still - # be entries from the previous day. There is an edge case - # where if the user logs in at 23:59 and overwrites their - # last_seen at 00:01 then they will not be counted in the - # previous day's stats - it is important that the query is run - # often to minimise this case. - if today_start > self._last_user_visit_update: - yesterday_start = today_start - a_day_in_milliseconds - txn.execute( - sql, - ( - yesterday_start, - yesterday_start, - self._last_user_visit_update, - today_start, - ), - ) - self._last_user_visit_update = today_start - - txn.execute( - sql, (today_start, today_start, self._last_user_visit_update, now) - ) - # Update _last_user_visit_update to now. The reason to do this - # rather just clamping to the beginning of the day is to limit - # the size of the join - meaning that the query can be run more - # frequently - self._last_user_visit_update = now - - await self.db_pool.runInteraction( - "generate_user_daily_visits", _generate_user_daily_visits - ) - async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 92099f95ce..2c5a4fdbf6 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -12,6 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import calendar +import logging +import time +from typing import Dict from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import run_as_background_process @@ -21,6 +25,8 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +logger = logging.getLogger(__name__) + # Collect metrics on the number of forward extremities that exist. _extremities_collecter = GaugeBucketCollector( "synapse_forward_extremities", @@ -60,6 +66,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + # Used in _generate_user_daily_visits to keep track of progress + self._last_user_visit_update = self._get_start_of_day() + async def _read_forward_extremities(self): def fetch(txn): txn.execute( @@ -137,3 +146,189 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return count return await self.db_pool.runInteraction("count_daily_active_rooms", _count) + + async def count_daily_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 24 hours. + """ + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + return await self.db_pool.runInteraction( + "count_daily_users", self._count_users, yesterday + ) + + async def count_monthly_users(self) -> int: + """ + Counts the number of users who used this homeserver in the last 30 days. + Note this method is intended for phonehome metrics only and is different + from the mau figure in synapse.storage.monthly_active_users which, + amongst other things, includes a 3 day grace period before a user counts. + """ + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + return await self.db_pool.runInteraction( + "count_monthly_users", self._count_users, thirty_days_ago + ) + + def _count_users(self, txn, time_from): + """ + Returns number of users seen in the past time_from period + """ + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ + txn.execute(sql, (time_from,)) + (count,) = txn.fetchone() + return count + + async def count_r30_users(self) -> Dict[str, int]: + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns: + A mapping of counts globally as well as broken out by platform. + """ + + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + for row in txn: + if row[0] == "unknown": + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) + + (count,) = txn.fetchone() + results["all"] = count + + return results + + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) + + def _get_start_of_day(self): + """ + Returns millisecond unixtime for start of UTC day. + """ + now = time.gmtime() + today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) + return today_start * 1000 + + async def generate_user_daily_visits(self) -> None: + """ + Generates daily visit data for use in cohort/ retention analysis + """ + + def _generate_user_daily_visits(txn): + logger.info("Calling _generate_user_daily_visits") + today_start = self._get_start_of_day() + a_day_in_milliseconds = 24 * 60 * 60 * 1000 + now = self._clock.time_msec() + + sql = """ + INSERT INTO user_daily_visits (user_id, device_id, timestamp) + SELECT u.user_id, u.device_id, ? + FROM user_ips AS u + LEFT JOIN ( + SELECT user_id, device_id, timestamp FROM user_daily_visits + WHERE timestamp = ? + ) udv + ON u.user_id = udv.user_id AND u.device_id=udv.device_id + INNER JOIN users ON users.name=u.user_id + WHERE last_seen > ? AND last_seen <= ? + AND udv.timestamp IS NULL AND users.is_guest=0 + AND users.appservice_id IS NULL + GROUP BY u.user_id, u.device_id + """ + + # This means that the day has rolled over but there could still + # be entries from the previous day. There is an edge case + # where if the user logs in at 23:59 and overwrites their + # last_seen at 00:01 then they will not be counted in the + # previous day's stats - it is important that the query is run + # often to minimise this case. + if today_start > self._last_user_visit_update: + yesterday_start = today_start - a_day_in_milliseconds + txn.execute( + sql, + ( + yesterday_start, + yesterday_start, + self._last_user_visit_update, + today_start, + ), + ) + self._last_user_visit_update = today_start + + txn.execute( + sql, (today_start, today_start, self._last_user_visit_update, now) + ) + # Update _last_user_visit_update to now. The reason to do this + # rather just clamping to the beginning of the day is to limit + # the size of the join - meaning that the query can be run more + # frequently + self._last_user_visit_update = now + + await self.db_pool.runInteraction( + "generate_user_daily_visits", _generate_user_daily_visits + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e93aad33cd..b2127598ef 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -32,6 +32,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._max_mau_value = hs.config.max_mau_value + @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -124,60 +127,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): desc="user_last_seen_monthly_active", ) - -class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._mau_stats_only = hs.config.mau_stats_only - self._max_mau_value = hs.config.max_mau_value - - # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], - ) - - def _initialise_reserved_users(self, txn, threepids): - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn (cursor): - threepids (list[dict]): List of threepid dicts to reserve - """ - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting #6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - async def reap_monthly_active_users(self): """Cleans out monthly active user table to ensure that no stale entries exist. @@ -257,6 +206,58 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._mau_stats_only = hs.config.mau_stats_only + + # Do not add more reserved users than the total allowable number + # cur = LoggingTransaction( + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3c7630857f..c0f2af0785 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore): "count_public_rooms", _count_public_rooms_txn ) + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. + """ + + def f(txn): + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 + + return await self.db_pool.runInteraction("get_rooms", f) + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], @@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - async def get_room_count(self) -> int: - """Retrieve the total number of rooms. - """ - - def f(txn): - sql = "SELECT count(*) FROM rooms" - txn.execute(sql) - row = txn.fetchone() - return row[0] or 0 - - return await self.db_pool.runInteraction("get_rooms", f) - async def add_event_report( self, room_id: str, diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 3b9211a6d2..79b7ece330 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore): ) return [(row["user_agent"], row["ip"]) for row in rows] - -class UIAuthStore(UIAuthWorkerStore): async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore): iterable=session_ids, keyvalues={}, ) + + +class UIAuthStore(UIAuthWorkerStore): + pass diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py index 7657bddea5..e7aed092c2 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py @@ -17,7 +17,7 @@ import resource import mock -from synapse.app.homeserver import phone_stats_home +from synapse.app.phone_stats_home import phone_stats_home from tests.unittest import HomeserverTestCase diff --git a/tests/utils.py b/tests/utils.py index 4673872f88..7a927c7f74 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -276,7 +276,7 @@ def setup_test_homeserver( hs.setup() if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() + hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): database = hs.get_datastores().databases[0] -- cgit 1.5.1 From ec10bdd32bb52af73789f5f60b39135578a739b1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Oct 2020 15:09:31 +0100 Subject: Speed up unit tests when using PostgreSQL (#8450) --- changelog.d/8450.misc | 1 + synapse/storage/databases/main/events_worker.py | 13 ++++++++++++- tests/server.py | 4 ++++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8450.misc (limited to 'synapse/storage') diff --git a/changelog.d/8450.misc b/changelog.d/8450.misc new file mode 100644 index 0000000000..4e04c523ab --- /dev/null +++ b/changelog.d/8450.misc @@ -0,0 +1 @@ +Speed up unit tests when using PostgreSQL. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index f95679ebc4..723ced4ff0 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -74,6 +74,13 @@ class EventRedactBehaviour(Names): class EventsWorkerStore(SQLBaseStore): + # Whether to use dedicated DB threads for event fetching. This is only used + # if there are multiple DB threads available. When used will lock the DB + # thread for periods of time (so unit tests want to disable this when they + # run DB transactions on the main thread). See EVENT_QUEUE_* for more + # options controlling this. + USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True + def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -522,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore): if not event_list: single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): self._event_fetch_ongoing -= 1 return else: diff --git a/tests/server.py b/tests/server.py index b404ad4e2a..f7f5276b21 100644 --- a/tests/server.py +++ b/tests/server.py @@ -372,6 +372,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool.threadpool = ThreadPool(clock._reactor) pool.running = True + # We've just changed the Databases to run DB transactions on the same + # thread, so we need to disable the dedicated thread behaviour. + server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server -- cgit 1.5.1 From e3debf9682ed59b2972f236fe2982b6af0a9bb9a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Oct 2020 15:20:45 +0100 Subject: Add logging on startup/shutdown (#8448) This is so we can tell what is going on when things are taking a while to start up. The main change here is to ensure that transactions that are created during startup get correctly logged like normal transactions. --- changelog.d/8448.misc | 1 + scripts/synapse_port_db | 2 +- synapse/app/_base.py | 5 ++ synapse/storage/database.py | 89 ++++++++++++++++++---- synapse/storage/databases/__init__.py | 2 +- synapse/storage/databases/main/__init__.py | 1 - .../storage/databases/main/event_push_actions.py | 8 +- .../storage/databases/main/monthly_active_users.py | 1 - synapse/storage/databases/main/roommember.py | 13 +--- .../databases/main/schema/delta/20/pushers.py | 19 +++-- .../storage/databases/main/schema/delta/25/fts.py | 2 - .../storage/databases/main/schema/delta/27/ts.py | 2 - .../databases/main/schema/delta/30/as_users.py | 6 +- .../databases/main/schema/delta/31/pushers.py | 19 +++-- .../main/schema/delta/31/search_update.py | 2 - .../databases/main/schema/delta/33/event_fields.py | 2 - .../main/schema/delta/33/remote_media_ts.py | 5 +- .../schema/delta/56/unique_user_filter_index.py | 7 +- .../schema/delta/57/local_current_membership.py | 1 - synapse/storage/prepare_database.py | 33 +++----- synapse/storage/types.py | 6 ++ synapse/storage/util/id_generators.py | 8 +- synapse/storage/util/sequence.py | 15 +++- tests/storage/test_appservice.py | 14 ++-- tests/utils.py | 2 + 25 files changed, 152 insertions(+), 113 deletions(-) create mode 100644 changelog.d/8448.misc (limited to 'synapse/storage') diff --git a/changelog.d/8448.misc b/changelog.d/8448.misc new file mode 100644 index 0000000000..5ddda1803b --- /dev/null +++ b/changelog.d/8448.misc @@ -0,0 +1 @@ +Add SQL logging on queries that happen during startup. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index ae2887b7d2..7e12f5440c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -489,7 +489,7 @@ class Porter(object): hs = MockHomeserver(self.hs_config) - with make_conn(db_config, engine) as db_conn: + with make_conn(db_config, engine, "portdb") as db_conn: engine.check_database( db_conn, allow_outdated_version=allow_outdated_version ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 8bb0b142ca..f6f7b2bf42 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -272,6 +272,11 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]): hs.get_datastore().db_pool.start_profiling() hs.get_pusherpool().start() + # Log when we start the shut down process. + hs.get_reactor().addSystemEventTrigger( + "before", "shutdown", logger.info, "Shutting down..." + ) + setup_sentry(hs) setup_sdnotify(hs) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 79ec8f119d..0d9d9b7cc0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from typing import ( overload, ) +import attr from prometheus_client import Histogram from typing_extensions import Literal @@ -90,13 +91,17 @@ def make_pool( return adbapi.ConnectionPool( db_config.config["name"], cp_reactor=reactor, - cp_openfun=engine.on_new_connection, + cp_openfun=lambda conn: engine.on_new_connection( + LoggingDatabaseConnection(conn, engine, "on_new_connection") + ), **db_config.config.get("args", {}) ) def make_conn( - db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, + default_txn_name: str, ) -> Connection: """Make a new connection to the database and return it. @@ -109,11 +114,60 @@ def make_conn( for k, v in db_config.config.get("args", {}).items() if not k.startswith("cp_") } - db_conn = engine.module.connect(**db_params) + native_db_conn = engine.module.connect(**db_params) + db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + engine.on_new_connection(db_conn) return db_conn +@attr.s(slots=True) +class LoggingDatabaseConnection: + """A wrapper around a database connection that returns `LoggingTransaction` + as its cursor class. + + This is mainly used on startup to ensure that queries get logged correctly + """ + + conn = attr.ib(type=Connection) + engine = attr.ib(type=BaseDatabaseEngine) + default_txn_name = attr.ib(type=str) + + def cursor( + self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + ) -> "LoggingTransaction": + if not txn_name: + txn_name = self.default_txn_name + + return LoggingTransaction( + self.conn.cursor(), + name=txn_name, + database_engine=self.engine, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, + ) + + def close(self) -> None: + self.conn.close() + + def commit(self) -> None: + self.conn.commit() + + def rollback(self, *args, **kwargs) -> None: + self.conn.rollback(*args, **kwargs) + + def __enter__(self) -> "Connection": + self.conn.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + return self.conn.__exit__(exc_type, exc_value, traceback) + + # Proxy through any unknown lookups to the DB conn class. + def __getattr__(self, name): + return getattr(self.conn, name) + + # The type of entry which goes on our after_callbacks and exception_callbacks lists. # # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so @@ -247,6 +301,12 @@ class LoggingTransaction: def close(self) -> None: self.txn.close() + def __enter__(self) -> "LoggingTransaction": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + class PerformanceCounters: def __init__(self): @@ -395,7 +455,7 @@ class DatabasePool: def new_transaction( self, - conn: Connection, + conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], @@ -418,12 +478,10 @@ class DatabasePool: i = 0 N = 5 while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.engine, - after_callbacks, - exception_callbacks, + cursor = conn.cursor( + txn_name=name, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, ) try: r = func(cursor, *args, **kwargs) @@ -584,7 +642,10 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - return func(conn, *args, **kwargs) + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) @@ -1621,7 +1682,7 @@ class DatabasePool: def get_cache_dict( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, entity_column: str, stream_column: str, @@ -1642,9 +1703,7 @@ class DatabasePool: "limit": limit, } - sql = self.engine.convert_param_style(sql) - - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="get_cache_dict") txn.execute(sql, (int(max_value),)) cache = {row[0]: int(row[1]) for row in txn} diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index aa5d490624..0c24325011 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -46,7 +46,7 @@ class Databases: db_name = database_config.name engine = create_engine(database_config.config) - with make_conn(database_config, engine) as db_conn: + with make_conn(database_config, engine, "startup") as db_conn: logger.info("[database config %r]: Checking database server", db_name) engine.check_database(db_conn) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index f823d66709..9b16f45f3e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -284,7 +284,6 @@ class DataStore( " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) - sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 62f1738732..80f3b4d740 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -74,11 +74,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): self.stream_ordering_month_ago = None self.stream_ordering_day_ago = None - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - ) + cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") self._find_stream_orderings_for_times_txn(cur) cur.close() diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b2127598ef..c66f558567 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -214,7 +214,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): self._mau_stats_only = hs.config.mau_stats_only # Do not add more reserved users than the total allowable number - # cur = LoggingTransaction( self.db_pool.new_transaction( db_conn, "initialise_mau_threepids", diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 86ffe2479e..bae1bd22d3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -21,12 +21,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - LoggingTransaction, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine @@ -60,10 +55,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): # background update still running? self._current_state_events_membership_up_to_date = False - txn = LoggingTransaction( - db_conn.cursor(), - name="_check_safe_current_state_events_membership_updated", - database_engine=self.database_engine, + txn = db_conn.cursor( + txn_name="_check_safe_current_state_events_membership_updated" ) self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py index 3edfcfd783..45b846e6a7 100644 --- a/synapse/storage/databases/main/schema/delta/20/pushers.py +++ b/synapse/storage/databases/main/schema/delta/20/pushers.py @@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs): row[8] = bytes(row[8]).decode("utf-8") row[11] = bytes(row[11]).decode("utf-8") cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_token, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_token, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py index ee675e71ff..21f57825d4 100644 --- a/synapse/storage/databases/main/schema/delta/25/fts.py +++ b/synapse/storage/databases/main/schema/delta/25/fts.py @@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py index b7972cfa8e..1c6058063f 100644 --- a/synapse/storage/databases/main/schema/delta/27/ts.py +++ b/synapse/storage/databases/main/schema/delta/27/ts.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py index b42c02710a..7f08fabe9f 100644 --- a/synapse/storage/databases/main/schema/delta/30/as_users.py +++ b/synapse/storage/databases/main/schema/delta/30/as_users.py @@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name IN (%s)" - % (",".join("?" for _ in chunk),) - ), + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" + % (",".join("?" for _ in chunk),), [as_id] + chunk, ) diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py index 9bb504aad5..5be81c806a 100644 --- a/synapse/storage/databases/main/schema/delta/31/pushers.py +++ b/synapse/storage/databases/main/schema/delta/31/pushers.py @@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs): row = list(row) row[12] = token_to_stream_ordering(row[12]) cur.execute( - database_engine.convert_param_style( - """ - INSERT into pushers2 ( - id, user_name, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, ts, lang, data, last_stream_ordering, last_success, - failing_since - ) values (%s)""" - % (",".join(["?" for _ in range(len(row))])) - ), + """ + INSERT into pushers2 ( + id, user_name, access_token, profile_tag, kind, + app_id, app_display_name, device_display_name, + pushkey, ts, lang, data, last_stream_ordering, last_success, + failing_since + ) values (%s) + """ + % (",".join(["?" for _ in range(len(row))])), row, ) count += 1 diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py index 63b757ade6..b84c844e3a 100644 --- a/synapse/storage/databases/main/schema/delta/31/search_update.py +++ b/synapse/storage/databases/main/schema/delta/31/search_update.py @@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_search_order", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py index a3e81eeac7..e928c66a8f 100644 --- a/synapse/storage/databases/main/schema/delta/33/event_fields.py +++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py @@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs): " VALUES (?, ?)" ) - sql = database_engine.convert_param_style(sql) - cur.execute(sql, ("event_fields_sender_url", progress_json)) diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py index a26057dfb6..ad875c733a 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py @@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - database_engine.convert_param_style( - "UPDATE remote_media_cache SET last_access_ts = ?" - ), - (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py index 1de8b54961..bb7296852a 100644 --- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py +++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py @@ -1,6 +1,8 @@ import logging +from io import StringIO from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream logger = logging.getLogger(__name__) @@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs): select_clause, ) - if isinstance(database_engine, PostgresEngine): - cur.execute(sql) - else: - cur.executescript(sql) + execute_statements_from_stream(cur, StringIO(sql)) diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py index 63b5acdcf7..44917f0a2e 100644 --- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py +++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py @@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): INNER JOIN room_memberships AS r USING (event_id) WHERE type = 'm.room.member' AND state_key LIKE ? """ - sql = database_engine.convert_param_style(sql) cur.execute(sql, ("%:" + config.server_name,)) cur.execute( diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 4957e77f4c..459754feab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import imp import logging import os @@ -24,9 +23,10 @@ from typing import Optional, TextIO import attr from synapse.config.homeserver import HomeServerConfig +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.engines.postgres import PostgresEngine -from synapse.storage.types import Connection, Cursor +from synapse.storage.types import Cursor from synapse.types import Collection logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = ( def prepare_database( - db_conn: Connection, + db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], databases: Collection[str] = ["main", "state"], @@ -89,7 +89,7 @@ def prepare_database( """ try: - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="prepare_database") # sqlite does not automatically start transactions for DDL / SELECT statements, # so we start one before running anything. This ensures that any upgrades @@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases): executescript(cur, entry.absolute_path) cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (max_current_ver, False), ) @@ -486,17 +484,13 @@ def _upgrade_existing_database( # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)" - ), + "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)", (v, relative_path), ) cur.execute("DELETE FROM schema_version") cur.execute( - database_engine.convert_param_style( - "INSERT INTO schema_version (version, upgraded) VALUES (?,?)" - ), + "INSERT INTO schema_version (version, upgraded) VALUES (?,?)", (v, True), ) @@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) schemas to be applied """ cur.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_module_schemas WHERE module_name = ?" - ), - (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: @@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( - database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" - ), + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)", (modname, name), ) @@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine): if current_version: txn.execute( - database_engine.convert_param_style( - "SELECT file FROM applied_schema_deltas WHERE version >= ?" - ), + "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), ) applied_deltas = [d for d, in txn] diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 2d2b560e74..970bb1b9da 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -61,3 +61,9 @@ class Connection(Protocol): def rollback(self, *args, **kwargs) -> None: ... + + def __enter__(self) -> "Connection": + ... + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index c92cd4a6ba..51f680d05d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -54,7 +54,7 @@ def _load_current_id(db_conn, table, column, step=1): """ # debug logging for https://github.com/matrix-org/synapse/issues/7968 logger.info("initialising stream generator for %s(%s)", table, column) - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_id") if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: @@ -269,7 +269,7 @@ class MultiWriterIdGenerator: def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ): - cur = db_conn.cursor() + cur = db_conn.cursor(txn_name="_load_current_ids") # Load the current positions of all writers for the stream. if self._writers: @@ -283,15 +283,12 @@ class MultiWriterIdGenerator: stream_name = ? AND instance_name != ALL(?) """ - sql = self._db.engine.convert_param_style(sql) cur.execute(sql, (self._stream_name, self._writers)) sql = """ SELECT instance_name, stream_id FROM stream_positions WHERE stream_name = ? """ - sql = self._db.engine.convert_param_style(sql) - cur.execute(sql, (self._stream_name,)) self._current_positions = { @@ -340,7 +337,6 @@ class MultiWriterIdGenerator: "instance": instance_column, "cmp": "<=" if self._positive else ">=", } - sql = self._db.engine.convert_param_style(sql) cur.execute(sql, (min_stream_id * self._return_factor,)) self._persisted_upto_position = min_stream_id diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 2dd95e2709..ff2d038ad2 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -17,6 +17,7 @@ import logging import threading from typing import Callable, List, Optional +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import ( BaseDatabaseEngine, IncorrectDatabaseSetup, @@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, ): """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. @@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator): return [i for (i,) in txn] def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: LoggingDatabaseConnection, + table: str, + id_column: str, + positive: bool = True, ): - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="sequence.check_consistency") # First we get the current max ID from the table. table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % { diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 46f94914ff..c905a38930 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): # must be done after inserts database = hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, make_conn(database._database_config, database.engine, "test"), hs ) def tearDown(self): @@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): db_config = hs.config.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine), hs + database, make_conn(db_config, self.engine, "test"), hs ) def _add_service(self, url, as_token, id): @@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, make_conn(database._database_config, database.engine, "test"), hs ) @defer.inlineCallbacks @@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, + make_conn(database._database_config, database.engine, "test"), + hs, ) e = cm.exception @@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine), hs + database, + make_conn(database._database_config, database.engine, "test"), + hs, ) e = cm.exception diff --git a/tests/utils.py b/tests/utils.py index 7a927c7f74..af563ffe0f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -38,6 +38,7 @@ from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database from synapse.util.ratelimitutils import FederationRateLimiter @@ -88,6 +89,7 @@ def setupdb(): host=POSTGRES_HOST, password=POSTGRES_PASSWORD, ) + db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(db_conn, db_engine, None) db_conn.close() -- cgit 1.5.1 From c5251c6fbd2722d54d33e02021f286053e611efc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 5 Oct 2020 09:28:05 -0400 Subject: Do not assume that account data is of the correct form. (#8454) This fixes a bug where `m.ignored_user_list` was assumed to be a dict, leading to odd behavior for users who set it to something else. --- changelog.d/8454.bugfix | 1 + synapse/api/constants.py | 5 +++++ synapse/handlers/room_member.py | 6 +++--- synapse/handlers/sync.py | 19 +++++++++++-------- synapse/storage/databases/main/account_data.py | 9 +++++++-- synapse/visibility.py | 15 +++++++-------- 6 files changed, 34 insertions(+), 21 deletions(-) create mode 100644 changelog.d/8454.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/8454.bugfix b/changelog.d/8454.bugfix new file mode 100644 index 0000000000..c06d490b6f --- /dev/null +++ b/changelog.d/8454.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug where invalid ignored users in account data could break clients. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 46013cde15..592abd844b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -155,3 +155,8 @@ class EventContentFields: class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" DEFAULT = MEGOLM_V1_AES_SHA2 + + +class AccountDataTypes: + DIRECT = "m.direct" + IGNORED_USER_LIST = "m.ignored_user_list" diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5ec36f591d..567a14bd0a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 from synapse import types -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -247,7 +247,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): user_account_data, _ = await self.store.get_account_data_for_user(user_id) # Copy direct message state if applicable - direct_rooms = user_account_data.get("m.direct", {}) + direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {}) # Check which key this room is under if isinstance(direct_rooms, dict): @@ -258,7 +258,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Save back to user's m.direct account data await self.store.add_account_data_for_user( - user_id, "m.direct", direct_rooms + user_id, AccountDataTypes.DIRECT, direct_rooms ) break diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 260ec19b41..a998e6b7f6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup import attr from prometheus_client import Counter -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase from synapse.logging.context import current_context @@ -1378,13 +1378,16 @@ class SyncHandler: return set(), set(), set(), set() ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id=user_id + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id ) + # If there is ignored users account data and it matches the proper type, + # then use it. + ignored_users = frozenset() # type: FrozenSet[str] if ignored_account_data: - ignored_users = ignored_account_data.get("ignored_users", {}).keys() - else: - ignored_users = frozenset() + ignored_users_data = ignored_account_data.get("ignored_users", {}) + if isinstance(ignored_users_data, dict): + ignored_users = frozenset(ignored_users_data.keys()) if since_token: room_changes = await self._get_rooms_changed( @@ -1478,7 +1481,7 @@ class SyncHandler: return False async def _get_rooms_changed( - self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: """Gets the the changes that have happened since the last sync. """ @@ -1690,7 +1693,7 @@ class SyncHandler: return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) async def _get_all_rooms( - self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str] + self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1764,7 +1767,7 @@ class SyncHandler: async def _generate_room_entry( self, sync_result_builder: "SyncResultBuilder", - ignored_users: Set[str], + ignored_users: FrozenSet[str], room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], tags: Optional[Dict[str, Dict[str, Any]]], diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index ef81d73573..49ee23470d 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -18,6 +18,7 @@ import abc import logging from typing import Dict, List, Optional, Tuple +from synapse.api.constants import AccountDataTypes from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator @@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext ) -> bool: ignored_account_data = await self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", + AccountDataTypes.IGNORED_USER_LIST, ignorer_user_id, on_invalidate=cache_context.invalidate, ) if not ignored_account_data: return False - return ignored_user_id in ignored_account_data.get("ignored_users", {}) + try: + return ignored_user_id in ignored_account_data.get("ignored_users", {}) + except TypeError: + # The type of the ignored_users field is invalid. + return False class AccountDataStore(AccountDataWorkerStore): diff --git a/synapse/visibility.py b/synapse/visibility.py index e3da7744d2..527365498e 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,7 +16,7 @@ import logging import operator -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage from synapse.storage.state import StateFilter @@ -77,15 +77,14 @@ async def filter_events_for_client( ) ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( - "m.ignored_user_list", user_id + AccountDataTypes.IGNORED_USER_LIST, user_id ) - # FIXME: This will explode if people upload something incorrect. - ignore_list = frozenset( - ignore_dict_content.get("ignored_users", {}).keys() - if ignore_dict_content - else [] - ) + ignore_list = frozenset() + if ignore_dict_content: + ignored_users_dict = ignore_dict_content.get("ignored_users", {}) + if isinstance(ignored_users_dict, dict): + ignore_list = frozenset(ignored_users_dict.keys()) erased_senders = await storage.main.are_users_erased((e.sender for e in events)) -- cgit 1.5.1 From f31f8e63198cfe46af48d788dbb294aba9155e5a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 5 Oct 2020 14:43:14 +0100 Subject: Remove stream ordering from Metadata dict (#8452) There's no need for it to be in the dict as well as the events table. Instead, we store it in a separate attribute in the EventInternalMetadata object, and populate that on load. This means that we can rely on it being correctly populated for any event which has been persited to the database. --- changelog.d/8452.misc | 1 + synapse/events/__init__.py | 6 +++-- synapse/events/utils.py | 5 +++++ synapse/federation/sender/__init__.py | 2 ++ synapse/federation/sender/per_destination_queue.py | 2 ++ synapse/handlers/federation.py | 3 +++ synapse/handlers/message.py | 4 +++- synapse/handlers/room_member.py | 13 ++++++----- synapse/rest/admin/__init__.py | 5 ++++- synapse/storage/databases/main/events.py | 4 ++++ synapse/storage/databases/main/events_worker.py | 26 +++++++++++++--------- synapse/storage/databases/main/stream.py | 13 ----------- synapse/storage/persist_events.py | 2 ++ 13 files changed, 53 insertions(+), 33 deletions(-) create mode 100644 changelog.d/8452.misc (limited to 'synapse/storage') diff --git a/changelog.d/8452.misc b/changelog.d/8452.misc new file mode 100644 index 0000000000..8288d91c78 --- /dev/null +++ b/changelog.d/8452.misc @@ -0,0 +1 @@ +Remove redundant databae loads of stream_ordering for events we already have. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index dc49df0812..7a51d0a22f 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -97,13 +97,16 @@ class DefaultDictProperty(DictProperty): class _EventInternalMetadata: - __slots__ = ["_dict"] + __slots__ = ["_dict", "stream_ordering"] def __init__(self, internal_metadata_dict: JsonDict): # we have to copy the dict, because it turns out that the same dict is # reused. TODO: fix that self._dict = dict(internal_metadata_dict) + # the stream ordering of this event. None, until it has been persisted. + self.stream_ordering = None # type: Optional[int] + outlier = DictProperty("outlier") # type: bool out_of_band_membership = DictProperty("out_of_band_membership") # type: bool send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str @@ -113,7 +116,6 @@ class _EventInternalMetadata: redacted = DictProperty("redacted") # type: bool txn_id = DictProperty("txn_id") # type: str token_id = DictProperty("token_id") # type: str - stream_ordering = DictProperty("stream_ordering") # type: int # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 32c73d3413..355cbe05f1 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -49,6 +49,11 @@ def prune_event(event: EventBase) -> EventBase: pruned_event_dict, event.room_version, event.internal_metadata.get_dict() ) + # copy the internal fields + pruned_event.internal_metadata.stream_ordering = ( + event.internal_metadata.stream_ordering + ) + # Mark the event as redacted pruned_event.internal_metadata.redacted = True diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8bb17b3a05..e33b29a42c 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -297,6 +297,8 @@ class FederationSender: sent_pdus_destination_dist_total.inc(len(destinations)) sent_pdus_destination_dist_count.inc() + assert pdu.internal_metadata.stream_ordering + # track the fact that we have a PDU for these destinations, # to allow us to perform catch-up later on if the remote is unreachable # for a while. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index bc99af3fdd..db8e456fe8 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -158,6 +158,7 @@ class PerDestinationQueue: # yet know if we have anything to catch up (None) self._pending_pdus.append(pdu) else: + assert pdu.internal_metadata.stream_ordering self._catchup_last_skipped = pdu.internal_metadata.stream_ordering self.attempt_new_transaction() @@ -361,6 +362,7 @@ class PerDestinationQueue: last_successful_stream_ordering = ( final_pdu.internal_metadata.stream_ordering ) + assert last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( self._destination, last_successful_stream_ordering ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1a8144405a..5ac2fc5656 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -3008,6 +3008,9 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return + # the event has been persisted so it should have a stream ordering. + assert event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ee271e85e5..00513fbf37 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -682,7 +682,9 @@ class EventCreationHandler: event.event_id, prev_event.event_id, ) - return await self.store.get_stream_id_for_event(prev_event.event_id) + # we know it was persisted, so must have a stream ordering + assert prev_event.internal_metadata.stream_ordering + return prev_event.internal_metadata.stream_ordering return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 567a14bd0a..13b749b7cb 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -194,8 +194,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - _, stream_id = await self.store.get_event_ordering(duplicate.event_id) - return duplicate.event_id, stream_id + # we know it was persisted, so must have a stream ordering. + assert duplicate.internal_metadata.stream_ordering + return duplicate.event_id, duplicate.internal_metadata.stream_ordering prev_state_ids = await context.get_prev_state_ids() @@ -441,12 +442,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - _, stream_id = await self.store.get_event_ordering( - old_state.event_id - ) + # duplicate event. + # we know it was persisted, so must have a stream ordering. + assert old_state.internal_metadata.stream_ordering return ( old_state.event_id, - stream_id, + old_state.internal_metadata.stream_ordering, ) if old_membership in ["ban", "leave"] and action == "kick": diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 57cac22252..789431ef25 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -57,6 +57,7 @@ from synapse.rest.admin.users import ( UsersRestServletV2, WhoisRestServlet, ) +from synapse.types import RoomStreamToken from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -109,7 +110,9 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - room_token = await self.store.get_topological_token_for_event(event_id) + room_token = RoomStreamToken( + event.depth, event.internal_metadata.stream_ordering + ) token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 78e645592f..b4abd961b9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -331,6 +331,10 @@ class PersistEventsStore: min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + # stream orderings should have been assigned by now + assert min_stream_order + assert max_stream_order + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremeties, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 723ced4ff0..b7ed8ca6ab 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -723,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) + original_ev.internal_metadata.stream_ordering = row["stream_ordering"] event_map[event_id] = original_ev @@ -790,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore): * event_id (str) + * stream_ordering (int): stream ordering for this event + * json (str): json-encoded event structure * internal_metadata (str): json-encoded internal metadata dict @@ -822,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore): sql = """\ SELECT e.event_id, - e.internal_metadata, - e.json, - e.format_version, + e.stream_ordering, + ej.internal_metadata, + ej.json, + ej.format_version, r.room_version, rej.reason - FROM event_json as e - LEFT JOIN rooms r USING (room_id) + FROM events AS e + JOIN event_json AS ej USING (event_id) + LEFT JOIN rooms r ON r.room_id = e.room_id LEFT JOIN rejections as rej USING (event_id) WHERE """ @@ -842,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore): event_id = row[0] event_dict[event_id] = { "event_id": event_id, - "internal_metadata": row[1], - "json": row[2], - "format_version": row[3], - "room_version_id": row[4], - "rejected_reason": row[5], + "stream_ordering": row[1], + "internal_metadata": row[2], + "json": row[3], + "format_version": row[4], + "room_version_id": row[5], + "rejected_reason": row[6], "redactions": [], } diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 1d27439536..a94bec1ac5 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -589,19 +589,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) return "t%d-%d" % (topo, token) - async def get_stream_id_for_event(self, event_id: str) -> int: - """The stream ID for an event - Args: - event_id: The id of the event to look up a stream token for. - Raises: - StoreError if the event wasn't in the database. - Returns: - A stream ID. - """ - return await self.db_pool.runInteraction( - "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id, - ) - def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, allow_none=False, ) -> int: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 72939f3984..4d2d88d1f0 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -248,6 +248,8 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) event_stream_id = event.internal_metadata.stream_ordering + # stream ordering should have been assigned by now + assert event_stream_id pos = PersistedEventPosition(self._instance_name, event_stream_id) return pos, self.main_store.get_room_max_token() -- cgit 1.5.1 From 3cd78bbe9e208d2e93ccebee5d3586ee5f5a5d31 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 6 Oct 2020 13:26:29 -0400 Subject: Add support for MSC2732: olm fallback keys (#8312) --- changelog.d/8312.feature | 1 + scripts/synapse_port_db | 1 + synapse/handlers/e2e_keys.py | 16 ++++ synapse/handlers/sync.py | 8 ++ synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/databases/main/end_to_end_keys.py | 100 ++++++++++++++++++++- .../databases/main/schema/delta/58/11fallback.sql | 24 +++++ tests/handlers/test_e2e_keys.py | 65 ++++++++++++++ 8 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8312.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/11fallback.sql (limited to 'synapse/storage') diff --git a/changelog.d/8312.feature b/changelog.d/8312.feature new file mode 100644 index 0000000000..222a1b032a --- /dev/null +++ b/changelog.d/8312.feature @@ -0,0 +1 @@ +Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)). \ No newline at end of file diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 7e12f5440c..2d0b59ab53 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = { "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], "users": ["shadow_banned"], + "e2e_fallback_keys_json": ["used"], } diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index dd40fd1299..611742ae72 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -496,6 +496,22 @@ class E2eKeysHandler: log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) + fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None) + if fallback_keys and isinstance(fallback_keys, dict): + log_kv( + { + "message": "Updating fallback_keys for device.", + "user_id": user_id, + "device_id": device_id, + } + ) + await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys) + elif fallback_keys: + log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"}) + else: + log_kv( + {"message": "Did not update fallback_keys", "reason": "no keys given"} + ) # the device should have been registered already, but it may have been # deleted due to a race with a DELETE request. Or we may be using an diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a998e6b7f6..dd1f90e359 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -201,6 +201,8 @@ class SyncResult: device_lists: List of user_ids whose devices have changed device_one_time_keys_count: Dict of algorithm to count for one time keys for this device + device_unused_fallback_key_types: List of key types that have an unused fallback + key groups: Group updates, if any """ @@ -213,6 +215,7 @@ class SyncResult: to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) device_one_time_keys_count = attr.ib(type=JsonDict) + device_unused_fallback_key_types = attr.ib(type=List[str]) groups = attr.ib(type=Optional[GroupsSyncResult]) def __bool__(self) -> bool: @@ -1014,10 +1017,14 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict + unused_fallback_key_types = [] # type: List[str] if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) + unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( + user_id, device_id + ) logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) @@ -1041,6 +1048,7 @@ class SyncHandler: device_lists=device_lists, groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, + device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 6779df952f..2b84eb89c0 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -236,6 +236,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, "next_batch": await sync_result.next_batch.to_string(self.store), } diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 22e1ed15d0..8c97f2af5c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: JsonDict + ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + values={ + "key_id": key_id, + "key_json": json_encoder.encode(fallback_key), + "used": False, + }, + desc="set_e2e_fallback_key", + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_key_types( + self, user_id: str, device_id: str + ) -> List[str]: + """Returns the fallback key types that have an unused key. + + Args: + user_id: the user whose keys are being queried + device_id: the device whose keys are being queried + + Returns: + a list of key types + """ + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id, "used": False}, + retcol="algorithm", + desc="get_e2e_unused_fallback_key_types", + ) + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn: + otk_row = txn.fetchone() + if otk_row is not None: + key_id, key_json = otk_row device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + else: + # no one-time key available, so see if there's a fallback + # key + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + fallback_row = txn.fetchone() + if fallback_row is not None: + key_id, key_json, used = fallback_row + device_result[algorithm + ":" + key_id] = key_json + if not used: + used_fallbacks.append( + (user_id, device_id, algorithm, key_id) + ) + + # drop any one-time keys that were claimed sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + # mark fallback keys as used + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, + {"used": True}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) + return result return await self.db_pool.runInteraction( @@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + ) await self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644 index 0000000000..4ed981dbf8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 366dcfb670..4e9e3dcbc2 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -171,6 +171,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, ) + @defer.inlineCallbacks + def test_fallback_key(self): + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + fallback_key = {"alg1:k1": "key1"} + otk = {"alg1:k2": "key2"} + + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"org.matrix.msc2732.fallback_keys": fallback_key}, + ) + ) + + # claiming an OTK when no OTKs are available should return the fallback + # key + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # claiming an OTK again should return the same fallback key + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + + # if the user uploads a one-time key, the next claim should fetch the + # one-time key, and then go back to the fallback + yield defer.ensureDeferred( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": otk} + ) + ) + + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, + ) + + res = yield defer.ensureDeferred( + self.handler.claim_one_time_keys( + {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None + ) + ) + self.assertEqual( + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, + ) + @defer.inlineCallbacks def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" -- cgit 1.5.1 From 4cb44a158549e83d42061b02a8b704e7d5873b21 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 7 Oct 2020 08:00:17 -0400 Subject: Add support for MSC2697: Dehydrated devices (#8380) This allows a user to store an offline device on the server and then restore it at a subsequent login. --- changelog.d/8380.feature | 1 + synapse/handlers/device.py | 84 ++++++++++++- synapse/rest/client/v2_alpha/devices.py | 134 +++++++++++++++++++++ synapse/rest/client/v2_alpha/keys.py | 37 +++--- synapse/storage/databases/main/devices.py | 78 +++++++++++- synapse/storage/databases/main/end_to_end_keys.py | 7 +- synapse/storage/databases/main/registration.py | 32 ++++- .../main/schema/delta/58/11dehydration.sql | 20 +++ tests/handlers/test_device.py | 82 +++++++++++++ 9 files changed, 454 insertions(+), 21 deletions(-) create mode 100644 changelog.d/8380.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/11dehydration.sql (limited to 'synapse/storage') diff --git a/changelog.d/8380.feature b/changelog.d/8380.feature new file mode 100644 index 0000000000..05ccea19dc --- /dev/null +++ b/changelog.d/8380.feature @@ -0,0 +1 @@ +Add support for device dehydration ([MSC2697](https://github.com/matrix-org/matrix-doc/pull/2697)). diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b9d9098104..e883ed1e37 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -29,6 +29,7 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( + JsonDict, StreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -505,6 +506,85 @@ class DeviceHandler(DeviceWorkerHandler): # receive device updates. Mark this in DB. await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + async def store_dehydrated_device( + self, + user_id: str, + device_data: JsonDict, + initial_device_display_name: Optional[str] = None, + ) -> str: + """Store a dehydrated device for a user. If the user had a previous + dehydrated device, it is removed. + + Args: + user_id: the user that we are storing the device for + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + Returns: + device id of the dehydrated device + """ + device_id = await self.check_device_registered( + user_id, None, initial_device_display_name, + ) + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data + ) + if old_device_id is not None: + await self.delete_device(user_id, old_device_id) + return device_id + + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + + async def rehydrate_device( + self, user_id: str, access_token: str, device_id: str + ) -> dict: + """Process a rehydration request from the user. + + Args: + user_id: the user who is rehydrating the device + access_token: the access token used for the request + device_id: the ID of the device that will be rehydrated + Returns: + a dict containing {"success": True} + """ + success = await self.store.remove_dehydrated_device(user_id, device_id) + + if not success: + raise errors.NotFoundError() + + # If the dehydrated device was successfully deleted (the device ID + # matched the stored dehydrated device), then modify the access + # token to use the dehydrated device's ID and copy the old device + # display name to the dehydrated device, and destroy the old device + # ID + old_device_id = await self.store.set_device_for_access_token( + access_token, device_id + ) + old_device = await self.store.get_device(user_id, old_device_id) + await self.store.update_device(user_id, device_id, old_device["display_name"]) + # can't call self.delete_device because that will clobber the + # access token so call the storage layer directly + await self.store.delete_device(user_id, old_device_id) + await self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=old_device_id + ) + + # tell everyone that the old device is gone and that the dehydrated + # device has a new display name + await self.notify_device_update(user_id, [old_device_id, device_id]) + + return {"success": True} + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 7e174de692..af117cb27c 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +22,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from ._base import client_patterns, interactive_auth_handler @@ -151,7 +153,139 @@ class DeviceRestServlet(RestServlet): return 200, {} +class DehydratedDeviceServlet(RestServlet): + """Retrieve or store a dehydrated device. + + GET /org.matrix.msc2697.v2/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + PUT /org.matrix.msc2697/dehydrated_device + Content-Type: application/json + + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_GET(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + result = {"device_id": device_id, "device_data": device_data} + return (200, result) + else: + raise errors.NotFoundError("No dehydrated device available") + + async def on_PUT(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + if "device_data" not in submission: + raise errors.SynapseError( + 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_data"], dict): + raise errors.SynapseError( + 400, + "device_data must be an object", + errcode=errors.Codes.INVALID_PARAM, + ) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission["device_data"], + submission.get("initial_device_display_name", None), + ) + return 200, {"device_id": device_id} + + +class ClaimDehydratedDeviceServlet(RestServlet): + """Claim a dehydrated device. + + POST /org.matrix.msc2697.v2/dehydrated_device/claim + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "success": true, + } + + """ + + PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + + submission = parse_json_object_from_request(request) + + if "device_id" not in submission: + raise errors.SynapseError( + 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_id"], str): + raise errors.SynapseError( + 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM, + ) + + result = await self.device_handler.rehydrate_device( + requester.user.to_string(), + self.auth.get_access_token_from_request(request), + submission["device_id"], + ) + + return (200, result) + + def register_servlets(hs, http_server): DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 55c4606569..b91996c738 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -67,6 +68,7 @@ class KeyUploadServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") async def on_POST(self, request, device_id): @@ -75,23 +77,28 @@ class KeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) if device_id is not None: - # passing the device_id here is deprecated; however, we allow it - # for now for compatibility with older clients. + # Providing the device_id should only be done for setting keys + # for dehydrated devices; however, we allow it for any device for + # compatibility with older clients. if requester.device_id is not None and device_id != requester.device_id: - set_tag("error", True) - log_kv( - { - "message": "Client uploading keys for a different device", - "logged_in_id": requester.device_id, - "key_being_uploaded": device_id, - } - ) - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, + dehydrated_device = await self.device_handler.get_dehydrated_device( + user_id ) + if dehydrated_device is not None and device_id != dehydrated_device[0]: + set_tag("error", True) + log_kv( + { + "message": "Client uploading keys for a different device", + "logged_in_id": requester.device_id, + "key_being_uploaded": device_id, + } + ) + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fdf394c612..317d6cde95 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ from synapse.storage.database import ( make_tuple_comparison_clause, ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key -from synapse.util import json_encoder +from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -698,6 +698,80 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return ( + (row["device_id"], json_decoder.decode(row["device_data"])) if row else None + ) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + self.db_pool.simple_upsert_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + values={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. + + Args: + user_id: the user that we are storing the device for + device_id: the ID of the dehydrated device + device_data: the dehydrated device information + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json_encoder.encode(device_data), + ) + + async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: + """Remove a dehydrated device. + + Args: + user_id: the user that the dehydrated device belongs to + device_id: the ID of the dehydrated device + """ + count = await self.db_pool.simple_delete( + "dehydrated_devices", + {"user_id": user_id, "device_id": device_id}, + desc="remove_dehydrated_device", + ) + return count >= 1 + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 8c97f2af5c..359dc6e968 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -844,6 +844,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) self.db_pool.simple_delete_txn( txn, table="e2e_fallback_keys_json", diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index a83df7759d..16ba545740 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019,2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -964,6 +964,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, "access_tokens", {"token": token}, "device_id" + ) + + self.db_pool.simple_update_txn( + txn, "access_tokens", {"token": token}, {"device_id": device_id} + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) + + return old_device_id + + async def set_device_for_access_token(self, token: str, device_id: str) -> str: + """Sets the device ID associated with an access token. + + Args: + token: The access token to modify. + device_id: The new device ID. + Returns: + The old device ID associated with the access token. + """ + + return await self.db_pool.runInteraction( + "set_device_for_access_token", + self._set_device_for_access_token_txn, + token, + device_id, + ) + async def register_user( self, user_id: str, diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644 index 0000000000..7851a0a825 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,20 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL -- JSON-encoded client-defined data +); diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 969d44c787..4512c51311 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) ) self.reactor.advance(1000) + + +class DehydrationTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + self.handler = hs.get_device_handler() + self.registration = hs.get_registration_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + return hs + + def test_dehydrate_and_rehydrate_device(self): + user_id = "@boris:dehydration" + + self.get_success(self.store.register_user(user_id, "foobar")) + + # First check if we can store and fetch a dehydrated device + stored_dehydrated_device_id = self.get_success( + self.handler.store_dehydrated_device( + user_id=user_id, + device_data={"device_data": {"foo": "bar"}}, + initial_device_display_name="dehydrated device", + ) + ) + + retrieved_device_id, device_data = self.get_success( + self.handler.get_dehydrated_device(user_id=user_id) + ) + + self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) + self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) + + # Create a new login for the user and dehydrated the device + device_id, access_token = self.get_success( + self.registration.register_device( + user_id=user_id, device_id=None, initial_display_name="new device", + ) + ) + + # Trying to claim a nonexistent device should throw an error + self.get_failure( + self.handler.rehydrate_device( + user_id=user_id, + access_token=access_token, + device_id="not the right device ID", + ), + synapse.api.errors.NotFoundError, + ) + + # dehydrating the right devices should succeed and change our device ID + # to the dehydrated device's ID + res = self.get_success( + self.handler.rehydrate_device( + user_id=user_id, + access_token=access_token, + device_id=retrieved_device_id, + ) + ) + + self.assertEqual(res, {"success": True}) + + # make sure that our device ID has changed + user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) + + self.assertEqual(user_info["device_id"], retrieved_device_id) + + # make sure the device has the display name that was set from the login + res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) + + self.assertEqual(res["display_name"], "new device") + + # make sure that the device ID that we were initially assigned no longer exists + self.get_failure( + self.handler.get_device(user_id, device_id), + synapse.api.errors.NotFoundError, + ) + + # make sure that there's no device available for dehydrating now + ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + + self.assertIsNone(ret) -- cgit 1.5.1 From b460a088c647a6d3ea0e5a9f4f80d86bb9e303b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 7 Oct 2020 08:58:21 -0400 Subject: Add typing information to the device handler. (#8407) --- changelog.d/8407.misc | 1 + mypy.ini | 1 + synapse/handlers/device.py | 89 +++++++++++++++++++------------ synapse/storage/databases/main/devices.py | 6 +-- 4 files changed, 59 insertions(+), 38 deletions(-) create mode 100644 changelog.d/8407.misc (limited to 'synapse/storage') diff --git a/changelog.d/8407.misc b/changelog.d/8407.misc new file mode 100644 index 0000000000..d37002d75b --- /dev/null +++ b/changelog.d/8407.misc @@ -0,0 +1 @@ +Add typing information to the device handler. diff --git a/mypy.ini b/mypy.ini index e84ad04e41..a7ffb81ef1 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,7 @@ files = synapse/federation, synapse/handlers/auth.py, synapse/handlers/cas_handler.py, + synapse/handlers/device.py, synapse/handlers/directory.py, synapse/handlers/events.py, synapse/handlers/federation.py, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index e883ed1e37..debb1b4f29 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -29,8 +29,10 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( + Collection, JsonDict, StreamToken, + UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -42,13 +44,16 @@ from synapse.util.retryutils import NotRetryingDestination from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 class DeviceWorkerHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs @@ -106,7 +111,9 @@ class DeviceWorkerHandler(BaseHandler): @trace @measure_func("device.get_user_ids_changed") - async def get_user_ids_changed(self, user_id: str, from_token: StreamToken): + async def get_user_ids_changed( + self, user_id: str, from_token: StreamToken + ) -> JsonDict: """Get list of users that have had the devices updated, or have newly joined a room, that `user_id` may be interested in. """ @@ -222,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler): possibly_joined = possibly_changed & users_who_share_room possibly_left = (possibly_changed | possibly_left) - users_who_share_room else: - possibly_joined = [] - possibly_left = [] + possibly_joined = set() + possibly_left = set() result = {"changed": list(possibly_joined), "left": list(possibly_left)} @@ -231,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler): return result - async def on_federation_query_user_devices(self, user_id): + async def on_federation_query_user_devices(self, user_id: str) -> JsonDict: stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( user_id ) @@ -250,7 +257,7 @@ class DeviceWorkerHandler(BaseHandler): class DeviceHandler(DeviceWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_sender = hs.get_federation_sender() @@ -265,7 +272,7 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - def _check_device_name_length(self, name: str): + def _check_device_name_length(self, name: Optional[str]): """ Checks whether a device name is longer than the maximum allowed length. @@ -284,8 +291,11 @@ class DeviceHandler(DeviceWorkerHandler): ) async def check_device_registered( - self, user_id, device_id, initial_device_display_name=None - ): + self, + user_id: str, + device_id: Optional[str], + initial_device_display_name: Optional[str] = None, + ) -> str: """ If the given device has not been registered, register it with the supplied display name. @@ -293,12 +303,11 @@ class DeviceHandler(DeviceWorkerHandler): If no device_id is supplied, we make one up. Args: - user_id (str): @user:id - device_id (str | None): device id supplied by client - initial_device_display_name (str | None): device display name from - client + user_id: @user:id + device_id: device id supplied by client + initial_device_display_name: device display name from client Returns: - str: device id (generated if none was supplied) + device id (generated if none was supplied) """ self._check_device_name_length(initial_device_display_name) @@ -317,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler): # times in case of a clash. attempts = 0 while attempts < 5: - device_id = stringutils.random_string(10).upper() + new_device_id = stringutils.random_string(10).upper() new_device = await self.store.store_device( user_id=user_id, - device_id=device_id, + device_id=new_device_id, initial_device_display_name=initial_device_display_name, ) if new_device: - await self.notify_device_update(user_id, [device_id]) - return device_id + await self.notify_device_update(user_id, [new_device_id]) + return new_device_id attempts += 1 raise errors.StoreError(500, "Couldn't generate a device ID.") @@ -434,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler): @trace @measure_func("notify_device_update") - async def notify_device_update(self, user_id, device_ids): + async def notify_device_update( + self, user_id: str, device_ids: Collection[str] + ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. """ @@ -446,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler): user_id ) - hosts = set() + hosts = set() # type: Set[str] if self.hs.is_mine_id(user_id): hosts.update(get_domain_from_id(u) for u in users_who_share_room) hosts.discard(self.server_name) @@ -498,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler): self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) - async def user_left_room(self, user, room_id): + async def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() room_ids = await self.store.get_rooms_for_user(user_id) if not room_ids: @@ -586,7 +597,9 @@ class DeviceHandler(DeviceWorkerHandler): return {"success": True} -def _update_device_from_client_ips(device, client_ips): +def _update_device_from_client_ips( + device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] +) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -594,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips): class DeviceListUpdater: "Handles incoming device list updates from federation and updates the DB" - def __init__(self, hs, device_handler): + def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -603,7 +616,9 @@ class DeviceListUpdater: self._remote_edu_linearizer = Linearizer(name="remote_device_list") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = ( + {} + ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -626,7 +641,9 @@ class DeviceListUpdater: ) @trace - async def incoming_device_list_update(self, origin, edu_content): + async def incoming_device_list_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming device list update from federation. Responsible for parsing the EDU and adding to pending updates list. """ @@ -687,7 +704,7 @@ class DeviceListUpdater: await self._handle_device_updates(user_id) @measure_func("_incoming_device_list_update") - async def _handle_device_updates(self, user_id): + async def _handle_device_updates(self, user_id: str) -> None: "Actually handle pending updates." with (await self._remote_edu_linearizer.queue(user_id)): @@ -735,7 +752,9 @@ class DeviceListUpdater: stream_id for _, stream_id, _, _ in pending_updates ) - async def _need_to_do_resync(self, user_id, updates): + async def _need_to_do_resync( + self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]] + ) -> bool: """Given a list of updates for a user figure out if we need to do a full resync, or whether we have enough data that we can just apply the delta. """ @@ -766,7 +785,7 @@ class DeviceListUpdater: return False @trace - async def _maybe_retry_device_resync(self): + async def _maybe_retry_device_resync(self) -> None: """Retry to resync device lists that are out of sync, except if another retry is in progress. """ @@ -809,7 +828,7 @@ class DeviceListUpdater: async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """Fetches all devices for a user and updates the device cache with them. Args: @@ -833,7 +852,7 @@ class DeviceListUpdater: # it later. await self.store.mark_remote_user_device_cache_as_stale(user_id) - return + return None except (RequestSendFailed, HttpResponseException) as e: logger.warning( "Failed to handle device list update for %s: %s", user_id, e, @@ -850,12 +869,12 @@ class DeviceListUpdater: # next time we get a device list update for this user_id. # This makes it more likely that the device lists will # eventually become consistent. - return + return None except FederationDeniedError as e: set_tag("error", True) log_kv({"reason": "FederationDeniedError"}) logger.info(e) - return + return None except Exception as e: set_tag("error", True) log_kv( @@ -868,7 +887,7 @@ class DeviceListUpdater: # it later. await self.store.mark_remote_user_device_cache_as_stale(user_id) - return + return None log_kv({"result": result}) stream_id = result["stream_id"] devices = result["devices"] @@ -929,7 +948,7 @@ class DeviceListUpdater: user_id: str, master_key: Optional[Dict[str, Any]], self_signing_key: Optional[Dict[str, Any]], - ) -> list: + ) -> List[str]: """Process the given new master and self-signing key for the given remote user. Args: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 317d6cde95..2d0a6408b5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -911,7 +911,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: str + self, user_id: str, device_id: str, initial_device_display_name: Optional[str] ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1029,7 +1029,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def update_remote_device_list_cache_entry( - self, user_id: str, device_id: str, content: JsonDict, stream_id: int + self, user_id: str, device_id: str, content: JsonDict, stream_id: str ) -> None: """Updates a single device in the cache of a remote user's devicelist. @@ -1057,7 +1057,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_id: str, content: JsonDict, - stream_id: int, + stream_id: str, ) -> None: if content.get("deleted"): self.db_pool.simple_delete_txn( -- cgit 1.5.1 From 52a50e8686ec9af6c629004171748f41eae09f73 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 7 Oct 2020 15:15:33 +0100 Subject: Use vector clocks for room stream tokens. (#8439) Currently when using multiple event persisters we (in the worst case) don't tell clients about events until all event persisters have persisted new events after the original event. This is a suboptimal, especially if one of the event persisters goes down. To handle this, we encode the position of each event persister in the room tokens so that we can send events to clients immediately. To reduce the size of the token we do two things: 1. We create a unique immutable persistent mapping between instance names and a generated small integer ID, which we can encode in the tokens instead of the instance name; and 2. We encode the "persisted upto position" of the room token and then only explicitly include instances that have positions strictly greater than that. The new tokens look something like: `m3478~1.3488~2.3489`, where the first number is the min position, and the subsequent `-` separated pairs are the instance ID to positions map. (We use `.` and `~` as separators as they're URL safe and not already used by `StreamToken`). --- changelog.d/8439.misc | 1 + .../schema/delta/58/19instance_map.sql.postgres | 25 ++ synapse/storage/databases/main/stream.py | 280 ++++++++++++++++++--- synapse/types.py | 116 ++++++++- 4 files changed, 380 insertions(+), 42 deletions(-) create mode 100644 changelog.d/8439.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres (limited to 'synapse/storage') diff --git a/changelog.d/8439.misc b/changelog.d/8439.misc new file mode 100644 index 0000000000..237cb3b311 --- /dev/null +++ b/changelog.d/8439.misc @@ -0,0 +1 @@ +Allow events to be sent to clients sooner when using sharded event persisters. diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres new file mode 100644 index 0000000000..841186b826 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A unique and immutable mapping between instance name and an integer ID. This +-- lets us refer to instances via a small ID in e.g. stream tokens, without +-- having to encode the full name. +CREATE TABLE IF NOT EXISTS instance_map ( + instance_id SERIAL PRIMARY KEY, + instance_name TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name); diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a94bec1ac5..e3b9ff5ca6 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -53,7 +53,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import Collection, PersistedEventPosition, RoomStreamToken +from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache if TYPE_CHECKING: @@ -208,6 +210,55 @@ def _make_generic_sql_bound( ) +def _filter_results( + lower_token: Optional[RoomStreamToken], + upper_token: Optional[RoomStreamToken], + instance_name: str, + topological_ordering: int, + stream_ordering: int, +) -> bool: + """Returns True if the event persisted by the given instance at the given + topological/stream_ordering falls between the two tokens (taking a None + token to mean unbounded). + + Used to filter results from fetching events in the DB against the given + tokens. This is necessary to handle the case where the tokens include + position maps, which we handle by fetching more than necessary from the DB + and then filtering (rather than attempting to construct a complicated SQL + query). + """ + + event_historical_tuple = ( + topological_ordering, + stream_ordering, + ) + + if lower_token: + if lower_token.topological is not None: + # If these are historical tokens we compare the `(topological, stream)` + # tuples. + if event_historical_tuple <= lower_token.as_historical_tuple(): + return False + + else: + # If these are live tokens we compare the stream ordering against the + # writers stream position. + if stream_ordering <= lower_token.get_stream_pos_for_instance( + instance_name + ): + return False + + if upper_token: + if upper_token.topological is not None: + if upper_token.as_historical_tuple() < event_historical_tuple: + return False + else: + if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering: + return False + + return True + + def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create @@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): raise NotImplementedError() def get_room_max_token(self) -> RoomStreamToken: - return RoomStreamToken(None, self.get_room_max_stream_ordering()) + """Get a `RoomStreamToken` that marks the current maximum persisted + position of the events stream. Useful to get a token that represents + "now". + + The token returned is a "live" token that may have an instance_map + component. + """ + + min_pos = self._stream_id_gen.get_current_token() + + positions = {} + if isinstance(self._stream_id_gen, MultiWriterIdGenerator): + # The `min_pos` is the minimum position that we know all instances + # have finished persisting to, so we only care about instances whose + # positions are ahead of that. (Instance positions can be behind the + # min position as there are times we can work out that the minimum + # position is ahead of the naive minimum across all current + # positions. See MultiWriterIdGenerator for details) + positions = { + i: p + for i, p in self._stream_id_gen.get_positions().items() + if p > min_pos + } + + return RoomStreamToken(None, min_pos, positions) async def get_room_events_stream_for_rooms( self, @@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): if from_key == to_key: return [], from_key - from_id = from_key.stream - to_id = to_key.stream - - has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) + has_changed = self._events_stream_cache.has_entity_changed( + room_id, from_key.stream + ) if not has_changed: return [], from_key def f(txn): - sql = ( - "SELECT event_id, stream_ordering FROM events WHERE" - " room_id = ?" - " AND not outlier" - " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) - txn.execute(sql, (room_id, from_id, to_id, limit)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT event_id, instance_name, topological_ordering, stream_ordering + FROM events + WHERE + room_id = ? + AND not outlier + AND stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering %s LIMIT ? + """ % ( + order, + ) + txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ][:limit] return rows rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) @@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): [r.event_id for r in rows], get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._set_before_and_after(ret, rows, topo_order=False) if order.lower() == "desc": ret.reverse() @@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): async def get_membership_changes_for_user( self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken ) -> List[EventBase]: - from_id = from_key.stream - to_id = to_key.stream - if from_key == to_key: return [] - if from_id: + if from_key: has_changed = self._membership_stream_cache.has_entity_changed( - user_id, int(from_id) + user_id, int(from_key.stream) ) if not has_changed: return [] def f(txn): - sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e," - " room_memberships AS m" - " WHERE e.event_id = m.event_id" - " AND m.user_id = ?" - " AND e.stream_ordering > ? AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - ) - txn.execute(sql, (user_id, from_id, to_id)) - - rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + sql = """ + SELECT m.event_id, instance_name, topological_ordering, stream_ordering + FROM events AS e, room_memberships AS m + WHERE e.event_id = m.event_id + AND m.user_id = ? + AND e.stream_ordering > ? AND e.stream_ordering <= ? + ORDER BY e.stream_ordering ASC + """ + txn.execute(sql, (user_id, min_from_id, max_to_id,)) + + rows = [ + _EventDictReturn(event_id, None, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + from_key, + to_key, + instance_name, + topological_ordering, + stream_ordering, + ) + ] return rows @@ -966,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): else: order = "ASC" + # The bounds for the stream tokens are complicated by the fact + # that we need to handle the instance_map part of the tokens. We do this + # by fetching all events between the min stream token and the maximum + # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and + # then filtering the results. + if from_token.topological is not None: + from_bound = ( + from_token.as_historical_tuple() + ) # type: Tuple[Optional[int], int] + elif direction == "b": + from_bound = ( + None, + from_token.get_max_stream_pos(), + ) + else: + from_bound = ( + None, + from_token.stream, + ) + + to_bound = None # type: Optional[Tuple[Optional[int], int]] + if to_token: + if to_token.topological is not None: + to_bound = to_token.as_historical_tuple() + elif direction == "b": + to_bound = ( + None, + to_token.stream, + ) + else: + to_bound = ( + None, + to_token.get_max_stream_pos(), + ) + bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token.as_tuple(), - to_token=to_token.as_tuple() if to_token else None, + from_token=from_bound, + to_token=to_bound, engine=self.database_engine, ) @@ -980,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): bounds += " AND " + filter_clause args.extend(filter_args) - args.append(int(limit)) + # We fetch more events as we'll filter the result set + args.append(int(limit) * 2) select_keywords = "SELECT" join_clause = "" @@ -1002,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): select_keywords += "DISTINCT" sql = """ - %(select_keywords)s event_id, topological_ordering, stream_ordering + %(select_keywords)s + event_id, instance_name, + topological_ordering, stream_ordering FROM events %(join_clause)s WHERE outlier = ? AND room_id = ? AND %(bounds)s @@ -1017,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): txn.execute(sql, args) - rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn] + # Filter the result set. + rows = [ + _EventDictReturn(event_id, topological_ordering, stream_ordering) + for event_id, instance_name, topological_ordering, stream_ordering in txn + if _filter_results( + lower_token=to_token if direction == "b" else from_token, + upper_token=from_token if direction == "b" else to_token, + instance_name=instance_name, + topological_ordering=topological_ordering, + stream_ordering=stream_ordering, + ) + ][:limit] if rows: topo = rows[-1].topological_ordering @@ -1082,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): return (events, token) + @cached() + async def get_id_for_instance(self, instance_name: str) -> int: + """Get a unique, immutable ID that corresponds to the given Synapse worker instance. + """ + + def _get_id_for_instance_txn(txn): + instance_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + allow_none=True, + ) + if instance_id is not None: + return instance_id + + # If we don't have an entry upsert one. + # + # We could do this before the first check, and rely on the cache for + # efficiency, but each UPSERT causes the next ID to increment which + # can quickly bloat the size of the generated IDs for new instances. + self.db_pool.simple_upsert_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + values={}, + ) + + return self.db_pool.simple_select_one_onecol_txn( + txn, + table="instance_map", + keyvalues={"instance_name": instance_name}, + retcol="instance_id", + ) + + return await self.db_pool.runInteraction( + "get_id_for_instance", _get_id_for_instance_txn + ) + + @cached() + async def get_name_from_instance_id(self, instance_id: int) -> str: + """Get the instance name from an ID previously returned by + `get_id_for_instance`. + """ + + return await self.db_pool.simple_select_one_onecol( + table="instance_map", + keyvalues={"instance_id": instance_id}, + retcol="instance_name", + desc="get_name_from_instance_id", + ) + class StreamStore(StreamWorkerStore): def get_room_max_stream_ordering(self) -> int: diff --git a/synapse/types.py b/synapse/types.py index bd271f9f16..5bde67cc07 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -22,6 +22,7 @@ from typing import ( TYPE_CHECKING, Any, Dict, + Iterable, Mapping, MutableMapping, Optional, @@ -43,7 +44,7 @@ if TYPE_CHECKING: if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Container, Iterable, Sized + from typing import Container, Sized T_co = TypeVar("T_co", covariant=True) @@ -375,7 +376,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, cmp=False) class RoomStreamToken: """Tokens are positions between events. The token "s1" comes after event 1. @@ -397,6 +398,31 @@ class RoomStreamToken: event it comes after. Historic tokens start with a "t" followed by the "topological_ordering" id of the event it comes after, followed by "-", followed by the "stream_ordering" id of the event it comes after. + + There is also a third mode for live tokens where the token starts with "m", + which is sometimes used when using sharded event persisters. In this case + the events stream is considered to be a set of streams (one for each writer) + and the token encodes the vector clock of positions of each writer in their + respective streams. + + The format of the token in such case is an initial integer min position, + followed by the mapping of instance ID to position separated by '.' and '~': + + m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... + + The `min_pos` corresponds to the minimum position all writers have persisted + up to, and then only writers that are ahead of that position need to be + encoded. An example token is: + + m56~2.58~3.59 + + Which corresponds to a set of three (or more writers) where instances 2 and + 3 (these are instance IDs that can be looked up in the DB to fetch the more + commonly used instance names) are at positions 58 and 59 respectively, and + all other instances are at position 56. + + Note: The `RoomStreamToken` cannot have both a topological part and an + instance map. """ topological = attr.ib( @@ -405,6 +431,25 @@ class RoomStreamToken: ) stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) + instance_map = attr.ib( + type=Dict[str, int], + factory=dict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(dict), + ), + ) + + def __attrs_post_init__(self): + """Validates that both `topological` and `instance_map` aren't set. + """ + + if self.instance_map and self.topological: + raise ValueError( + "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." + ) + @classmethod async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": try: @@ -413,6 +458,20 @@ class RoomStreamToken: if string[0] == "t": parts = string[1:].split("-", 1) return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) + instance_map[instance_name] = pos + + return cls(topological=None, stream=stream, instance_map=instance_map,) except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -436,14 +495,61 @@ class RoomStreamToken: max_stream = max(self.stream, other.stream) - return RoomStreamToken(None, max_stream) + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, instance_map) + + def as_historical_tuple(self) -> Tuple[int, int]: + """Returns a tuple of `(topological, stream)` for historical tokens. + + Raises if not an historical token (i.e. doesn't have a topological part). + """ + if self.topological is None: + raise Exception( + "Cannot call `RoomStreamToken.as_historical_tuple` on live token" + ) - def as_tuple(self) -> Tuple[Optional[int], int]: return (self.topological, self.stream) + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token. + + This only makes sense for "live" tokens that may have a vector clock + component, and so asserts that this is a "live" token. + """ + assert self.topological is None + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + async def to_string(self, store: "DataStore") -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + instance_id = await store.get_id_for_instance(name) + entries.append("{}.{}".format(instance_id, pos)) + + encoded_map = "~".join(entries) + return "m{}~{}".format(self.stream, encoded_map) else: return "s%d" % (self.stream,) @@ -535,7 +641,7 @@ class PersistedEventPosition: stream = attr.ib(type=int) def persisted_after(self, token: RoomStreamToken) -> bool: - return token.stream < self.stream + return token.get_stream_pos_for_instance(self.instance_name) < self.stream def to_room_stream_token(self) -> RoomStreamToken: """Converts the position to a room stream token such that events -- cgit 1.5.1 From ae5b2a72c09d67311c9830f5a6fae1decce03e1f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 7 Oct 2020 15:15:57 +0100 Subject: Reduce serialization errors in MultiWriterIdGen (#8456) We call `_update_stream_positions_table_txn` a lot, which is an UPSERT that can conflict in `REPEATABLE READ` isolation level. Instead of doing a transaction consisting of a single query we may as well run it outside of a transaction. --- changelog.d/8456.misc | 1 + synapse/storage/database.py | 69 ++++++++++++++++++++++++++++++++--- synapse/storage/engines/_base.py | 17 +++++++++ synapse/storage/engines/postgres.py | 10 ++++- synapse/storage/engines/sqlite.py | 10 +++++ synapse/storage/util/id_generators.py | 12 +++++- tests/storage/test_base.py | 1 + 7 files changed, 112 insertions(+), 8 deletions(-) create mode 100644 changelog.d/8456.misc (limited to 'synapse/storage') diff --git a/changelog.d/8456.misc b/changelog.d/8456.misc new file mode 100644 index 0000000000..ccd260069b --- /dev/null +++ b/changelog.d/8456.misc @@ -0,0 +1 @@ +Reduce number of serialization errors of `MultiWriterIdGenerator._update_table`. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0d9d9b7cc0..0ba3a025cf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -463,6 +463,24 @@ class DatabasePool: *args: Any, **kwargs: Any ) -> R: + """Start a new database transaction with the given connection. + + Note: The given func may be called multiple times under certain + failure modes. This is normally fine when in a standard transaction, + but care must be taken if the connection is in `autocommit` mode that + the function will correctly handle being aborted and retried half way + through its execution. + + Args: + conn + desc + after_callbacks + exception_callbacks + func + *args + **kwargs + """ + start = monotonic_time() txn_id = self._TXN_ID @@ -566,7 +584,12 @@ class DatabasePool: sql_txn_timer.labels(desc).observe(duration) async def runInteraction( - self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + desc: str, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Starts a transaction on the database and runs a given function @@ -576,6 +599,18 @@ class DatabasePool: database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transactions + that are only a single query. + + Currently, this is only implemented for Postgres. SQLite will still + run the function inside a transaction. + + WARNING: This means that if func fails half way through then + the changes will *not* be rolled back. `func` may also get + called multiple times if the transaction is retried, so must + correctly handle that case. + args: positional args to pass to `func` kwargs: named args to pass to `func` @@ -596,6 +631,7 @@ class DatabasePool: exception_callbacks, func, *args, + db_autocommit=db_autocommit, **kwargs ) @@ -609,7 +645,11 @@ class DatabasePool: return cast(R, result) async def runWithConnection( - self, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -618,6 +658,9 @@ class DatabasePool: database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. args: positional args to pass to `func` + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transaction + that are only a single query. Currently only affects postgres. kwargs: named args to pass to `func` Returns: @@ -633,6 +676,13 @@ class DatabasePool: start_time = monotonic_time() def inner_func(conn, *args, **kwargs): + # We shouldn't be in a transaction. If we are then something + # somewhere hasn't committed after doing work. (This is likely only + # possible during startup, as `run*` will ensure changes are + # committed/rolled back before putting the connection back in the + # pool). + assert not self.engine.in_transaction(conn) + with LoggingContext("runWithConnection", parent_context) as context: sched_duration_sec = monotonic_time() - start_time sql_scheduling_timer.observe(sched_duration_sec) @@ -642,10 +692,17 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - db_conn = LoggingDatabaseConnection( - conn, self.engine, "runWithConnection" - ) - return func(db_conn, *args, **kwargs) + try: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, True) + + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) + finally: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, False) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 908cbc79e3..d6d632dc10 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): """Gets a string giving the server version. For example: '3.22.0' """ ... + + @abc.abstractmethod + def in_transaction(self, conn: Connection) -> bool: + """Whether the connection is currently in a transaction. + """ + ... + + @abc.abstractmethod + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + """Attempt to set the connections autocommit mode. + + When True queries are run outside of transactions. + + Note: This has no effect on SQLite3, so callers still need to + commit/rollback the connections. + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index ff39281f85..7719ac32f7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,7 +15,8 @@ import logging -from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup +from synapse.storage.types import Connection logger = logging.getLogger(__name__) @@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine): cursor.execute("SET synchronous_commit TO OFF") cursor.close() + db_conn.commit() @property def can_native_upsert(self): @@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine): return "%i.%i" % (numver / 10000, numver % 10000) else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) + + def in_transaction(self, conn: Connection) -> bool: + return conn.status != self.module.extensions.STATUS_READY # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + return conn.set_session(autocommit=autocommit) # type: ignore diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 8a0f8c89d1..5db0f0b520 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -17,6 +17,7 @@ import threading import typing from synapse.storage.engines import BaseDatabaseEngine +from synapse.storage.types import Connection if typing.TYPE_CHECKING: import sqlite3 # noqa: F401 @@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): db_conn.create_function("rank", 1, _rank) db_conn.execute("PRAGMA foreign_keys = ON;") + db_conn.commit() def is_deadlock(self, error): return False @@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): """ return "%i.%i.%i" % self.module.sqlite_version_info + def in_transaction(self, conn: Connection) -> bool: + return conn.in_transaction # type: ignore + + def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + # Twisted doesn't let us set attributes on the connections, so we can't + # set the connection to autocommit mode. + pass + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 51f680d05d..d7e40aaa8b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -24,6 +24,7 @@ from typing_extensions import Deque from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator logger = logging.getLogger(__name__) @@ -548,7 +549,7 @@ class MultiWriterIdGenerator: # do. break - def _update_stream_positions_table_txn(self, txn): + def _update_stream_positions_table_txn(self, txn: Cursor): """Update the `stream_positions` table with newly persisted position. """ @@ -598,10 +599,13 @@ class _MultiWriterCtxManager: stream_ids = attr.ib(type=List[int], factory=list) async def __aenter__(self) -> Union[int, List[int]]: + # It's safe to run this in autocommit mode as fetching values from a + # sequence ignores transaction semantics anyway. self.stream_ids = await self.id_gen._db.runInteraction( "_load_next_mult_id", self.id_gen._load_next_mult_id_txn, self.multiple_ids or 1, + db_autocommit=True, ) # Assert the fetched ID is actually greater than any ID we've already @@ -632,10 +636,16 @@ class _MultiWriterCtxManager: # # We only do this on the success path so that the persisted current # position points to a persisted row with the correct instance name. + # + # We do this in autocommit mode as a) the upsert works correctly outside + # transactions and b) reduces the amount of time the rows are locked + # for. If we don't do this then we'll often hit serialization errors due + # to the fact we default to REPEATABLE READ isolation levels. if self.id_gen._writers: await self.id_gen._db.runInteraction( "MultiWriterIdGenerator._update_table", self.id_gen._update_stream_positions_table_txn, + db_autocommit=True, ) return False diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 40ba652248..eac7e4dcd2 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -56,6 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False + fake_engine.in_transaction.return_value = False db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine) db._db_pool = self.db_pool -- cgit 1.5.1 From e4f72ddc44367d0cd53e6cfc5ba310b6f55319b6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 7 Oct 2020 11:27:56 -0400 Subject: Move additional tasks to the background worker (#8458) --- changelog.d/8458.feature | 1 + synapse/app/generic_worker.py | 4 + synapse/app/phone_stats_home.py | 33 ++--- synapse/storage/databases/main/client_ips.py | 109 ++++++++------- synapse/storage/databases/main/metrics.py | 14 +- synapse/storage/databases/main/registration.py | 184 ++++++++++++------------- synapse/storage/databases/main/roommember.py | 5 +- synapse/storage/databases/main/transactions.py | 42 +++--- 8 files changed, 195 insertions(+), 197 deletions(-) create mode 100644 changelog.d/8458.feature (limited to 'synapse/storage') diff --git a/changelog.d/8458.feature b/changelog.d/8458.feature new file mode 100644 index 0000000000..542993110b --- /dev/null +++ b/changelog.d/8458.feature @@ -0,0 +1 @@ +Allow running background tasks in a separate worker process. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fc5188ce95..d53181deb1 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -127,6 +127,7 @@ from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore +from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( @@ -135,6 +136,7 @@ from synapse.storage.databases.main.monthly_active_users import ( from synapse.storage.databases.main.presence import UserPresenceState from synapse.storage.databases.main.search import SearchWorkerStore from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -466,6 +468,7 @@ class GenericWorkerSlavedStore( SlavedAccountDataStore, SlavedPusherStore, CensorEventsStore, + ClientIpWorkerStore, SlavedEventStore, SlavedKeyStore, RoomStore, @@ -481,6 +484,7 @@ class GenericWorkerSlavedStore( MediaRepositoryStore, ServerMetricsStore, SearchWorkerStore, + TransactionWorkerStore, BaseSlavedStore, ): pass diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index daed8ccfe9..8a69104a04 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import math import resource @@ -19,7 +18,10 @@ import sys from prometheus_client import Gauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) logger = logging.getLogger("synapse.app.homeserver") @@ -41,6 +43,7 @@ registered_reserved_users_mau_gauge = Gauge( ) +@wrap_as_background_process("phone_stats_home") async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) @@ -143,20 +146,10 @@ def start_phone_stats_home(hs): (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) ) - def start_phone_stats_home(): - return run_as_background_process( - "phone_stats_home", phone_stats_home, hs, stats - ) - - def generate_user_daily_visit_stats(): - return run_as_background_process( - "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits - ) - # Rather than update on per session basis, batch up the requests. # If you increase the loop period, the accuracy of user_daily_visits # table will decrease - clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000) + clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000) # monthly active user limiting functionality def reap_monthly_active_users(): @@ -167,6 +160,7 @@ def start_phone_stats_home(hs): clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60) reap_monthly_active_users() + @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users(): current_mau_count = 0 current_mau_count_by_service = {} @@ -186,19 +180,14 @@ def start_phone_stats_home(hs): registered_reserved_users_mau_gauge.set(float(len(reserved_users))) max_mau_gauge.set(float(hs.config.max_mau_value)) - def start_generate_monthly_active_users(): - return run_as_background_process( - "generate_monthly_active_users", generate_monthly_active_users - ) - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: - start_generate_monthly_active_users() - clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000) + generate_monthly_active_users() + clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings if hs.config.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") - clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000) + clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) # We need to defer this init for the cases that we daemonize # otherwise the process ID we get is that of the non-daemon process @@ -206,4 +195,4 @@ def start_phone_stats_home(hs): # We wait 5 minutes to send the first set of stats as the server can # be quite busy the first few minutes - clock.call_later(5 * 60, start_phone_stats_home) + clock.call_later(5 * 60, phone_stats_home, hs, stats) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 239c7a949c..a25a888443 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -351,7 +351,63 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return updated -class ClientIpStore(ClientIpBackgroundUpdateStore): +class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.user_ips_max_age = hs.config.user_ips_max_age + + if hs.config.run_background_tasks and self.user_ips_max_age: + self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + + @wrap_as_background_process("prune_old_user_ips") + async def _prune_old_user_ips(self): + """Removes entries in user IPs older than the configured period. + """ + + if self.user_ips_max_age is None: + # Nothing to do + return + + if not await self.db_pool.updates.has_completed_background_update( + "devices_last_seen" + ): + # Only start pruning if we have finished populating the devices + # last seen info. + return + + # We do a slightly funky SQL delete to ensure we don't try and delete + # too much at once (as the table may be very large from before we + # started pruning). + # + # This works by finding the max last_seen that is less than the given + # time, but has no more than N rows before it, deleting all rows with + # a lesser last_seen time. (We COALESCE so that the sub-SELECT always + # returns exactly one row). + sql = """ + DELETE FROM user_ips + WHERE last_seen <= ( + SELECT COALESCE(MAX(last_seen), -1) + FROM ( + SELECT last_seen FROM user_ips + WHERE last_seen <= ? + ORDER BY last_seen ASC + LIMIT 5000 + ) AS u + ) + """ + + timestamp = self.clock.time_msec() - self.user_ips_max_age + + def _prune_old_user_ips_txn(txn): + txn.execute(sql, (timestamp,)) + + await self.db_pool.runInteraction( + "_prune_old_user_ips", _prune_old_user_ips_txn + ) + + +class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): self.client_ip_last_seen = Cache( @@ -360,8 +416,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age - # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) self._batch_row_update = {} @@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "before", "shutdown", self._update_client_ips_batch ) - if self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): @@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): } for (access_token, ip), (user_agent, last_seen) in results.items() ] - - @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ - - if self.user_ips_max_age is None: - # Nothing to do - return - - if not await self.db_pool.updates.has_completed_background_update( - "devices_last_seen" - ): - # Only start pruning if we have finished populating the devices - # last seen info. - return - - # We do a slightly funky SQL delete to ensure we don't try and delete - # too much at once (as the table may be very large from before we - # started pruning). - # - # This works by finding the max last_seen that is less than the given - # time, but has no more than N rows before it, deleting all rows with - # a lesser last_seen time. (We COALESCE so that the sub-SELECT always - # returns exactly one row). - sql = """ - DELETE FROM user_ips - WHERE last_seen <= ( - SELECT COALESCE(MAX(last_seen), -1) - FROM ( - SELECT last_seen FROM user_ips - WHERE last_seen <= ? - ORDER BY last_seen ASC - LIMIT 5000 - ) AS u - ) - """ - - timestamp = self.clock.time_msec() - self.user_ips_max_age - - def _prune_old_user_ips_txn(txn): - txn.execute(sql, (timestamp,)) - - await self.db_pool.runInteraction( - "_prune_old_user_ips", _prune_old_user_ips_txn - ) diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 2c5a4fdbf6..0acf0617ca 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -18,7 +18,7 @@ import time from typing import Dict from synapse.metrics import GaugeBucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_push_actions import ( @@ -57,18 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) - - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + if hs.config.run_background_tasks: + self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + @wrap_as_background_process("read_forward_extremities") async def _read_forward_extremities(self): def fetch(txn): txn.execute( @@ -274,6 +269,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 + @wrap_as_background_process("generate_user_daily_visits") async def generate_user_daily_visits(self) -> None: """ Generates daily visit data for use in cohort/ retention analysis diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 16ba545740..a85867936f 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -14,14 +14,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import re from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor @@ -48,6 +50,21 @@ class RegistrationWorkerStore(SQLBaseStore): database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) + self._account_validity = hs.config.account_validity + if hs.config.run_background_tasks and self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + if hs.config.run_background_tasks: + self.clock.looping_call( + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS + ) + @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( @@ -778,6 +795,78 @@ class RegistrationWorkerStore(SQLBaseStore): "delete_threepid_session", delete_threepid_session_txn ) + @wrap_as_background_process("cull_expired_threepid_validation_tokens") + async def cull_expired_threepid_validation_tokens(self) -> None: + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + async def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + await self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -911,28 +1000,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._account_validity = hs.config.account_validity self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - # Create a background job for culling expired 3PID validity tokens - def start_cull(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "cull_expired_threepid_validation_tokens", - self.cull_expired_threepid_validation_tokens, - ) - - hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - async def add_access_token_to_user( self, user_id: str, @@ -1477,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn(txn, ts): - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), - ) - async def set_user_deactivated_status( self, user_id: str, deactivated: bool ) -> None: @@ -1522,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - async def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.db_pool.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - await self.db_pool.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self.db_pool.simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bae1bd22d3..20fcdaa529 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -61,7 +61,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._check_safe_current_state_events_membership_updated_txn(txn) txn.close() - if self.hs.config.metrics_flags.known_servers: + if ( + self.hs.config.run_background_tasks + and self.hs.config.metrics_flags.known_servers + ): self._known_servers_count = 1 self.hs.get_clock().looping_call( run_as_background_process, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 97aed1500e..7d46090267 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple( SENTINEL = object() -class TransactionStore(SQLBaseStore): +class TransactionWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + if hs.config.run_background_tasks: + self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + + @wrap_as_background_process("cleanup_transactions") + async def _cleanup_transactions(self) -> None: + now = self._clock.time_msec() + month_ago = now - 30 * 24 * 60 * 60 * 1000 + + def _cleanup_transactions_txn(txn): + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + + await self.db_pool.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) + + +class TransactionStore(TransactionWorkerStore): """A collection of queries for handling PDUs. """ def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) - self._destination_retry_cache = ExpiringCache( cache_name="get_destination_retry_timings", clock=self._clock, @@ -266,22 +284,6 @@ class TransactionStore(SQLBaseStore): }, ) - def _start_cleanup_transactions(self): - return run_as_background_process( - "cleanup_transactions", self._cleanup_transactions - ) - - async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 - - def _cleanup_transactions_txn(txn): - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - - await self.db_pool.runInteraction( - "_cleanup_transactions", _cleanup_transactions_txn - ) - async def store_destination_rooms_entries( self, destinations: Iterable[str], room_id: str, stream_ordering: int, ) -> None: -- cgit 1.5.1 From a97cec18bba42e5cb743f61e79f253d6d95a0c0c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 8 Oct 2020 13:24:46 -0400 Subject: Invalidate the cache when an olm fallback key is uploaded (#8501) --- changelog.d/8501.feature | 1 + synapse/storage/databases/main/end_to_end_keys.py | 4 ++++ tests/handlers/test_e2e_keys.py | 20 ++++++++++++++++++++ 3 files changed, 25 insertions(+) create mode 100644 changelog.d/8501.feature (limited to 'synapse/storage') diff --git a/changelog.d/8501.feature b/changelog.d/8501.feature new file mode 100644 index 0000000000..5220ddd482 --- /dev/null +++ b/changelog.d/8501.feature @@ -0,0 +1 @@ +Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)). diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 359dc6e968..4415909414 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -398,6 +398,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): desc="set_e2e_fallback_key", ) + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) + ) + @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( self, user_id: str, device_id: str diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 4e9e3dcbc2..e79d612f7a 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -33,6 +33,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler + self.store = None # type: synapse.storage.Storage @defer.inlineCallbacks def setUp(self): @@ -40,6 +41,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): self.addCleanup, handlers=None, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + self.store = self.hs.get_datastore() @defer.inlineCallbacks def test_query_local_devices_no_devices(self): @@ -178,6 +180,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): fallback_key = {"alg1:k1": "key1"} otk = {"alg1:k2": "key2"} + # we shouldn't have any unused fallback keys yet + res = yield defer.ensureDeferred( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + yield defer.ensureDeferred( self.handler.upload_keys_for_user( local_user, @@ -186,6 +194,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) ) + # we should now have an unused alg1 key + res = yield defer.ensureDeferred( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, ["alg1"]) + # claiming an OTK when no OTKs are available should return the fallback # key res = yield defer.ensureDeferred( @@ -198,6 +212,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) + # we shouldn't have any unused fallback keys again + res = yield defer.ensureDeferred( + self.store.get_e2e_unused_fallback_key_types(local_user, device_id) + ) + self.assertEqual(res, []) + # claiming an OTK again should return the same fallback key res = yield defer.ensureDeferred( self.handler.claim_one_time_keys( -- cgit 1.5.1