From 47e28b4031c7c5e2c87824c2b4873492b996d02e Mon Sep 17 00:00:00 2001
From: Dagfinn Ilmari Mannsåker
Date: Tue, 6 Jul 2021 14:31:13 +0100
Subject: Ignore EDUs for rooms we're not in (#10317)

---
 tests/handlers/test_typing.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

(limited to 'tests')

diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f58afbc244..fa3cff598e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -38,6 +38,9 @@ U_ONION = UserID.from_string("@onion:farm")
 # Test room id
 ROOM_ID = "a-room"
 
+# Room we're not in
+OTHER_ROOM_ID = "another-room"
+

 def _expect_edu_transaction(edu_type, content, origin="test"):
     return {
@@ -115,6 +118,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs.get_auth().check_user_in_room = check_user_in_room
 
+        async def check_host_in_room(room_id, server_name):
+            return room_id == ROOM_ID
+
+        hs.get_event_auth_handler().check_host_in_room = check_host_in_room
+
         def get_joined_hosts_for_room(room_id):
             return {member.domain for member in self.room_members}
 
@@ -244,6 +252,35 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    def test_started_typing_remote_recv_not_in_room(self):
+        self.room_members = [U_APPLE, U_ONION]
+
+        self.assertEquals(self.event_source.get_current_key(), 0)
+
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/federation/v1/send/1000000",
+            _make_edu_transaction_json(
+                "m.typing",
+                content={
+                    "room_id": OTHER_ROOM_ID,
+                    "user_id": U_ONION.to_string(),
+                    "typing": True,
+                },
+            ),
+            federation_auth_origin=b"farm",
+        )
+        self.assertEqual(channel.code, 200)
+
+        self.on_new_event.assert_not_called()
+
+        self.assertEquals(self.event_source.get_current_key(), 0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
+        )
+        self.assertEquals(events[0], [])
+        self.assertEquals(events[1], 0)
+
     @override_config({"send_federation": True})
     def test_stopped_typing(self):
         self.room_members = [U_APPLE, U_BANANA, U_ONION]
--
cgit 1.5.1


From f6767abc054f3461cd9a70ba096fcf9a8e640edb Mon Sep 17 00:00:00 2001
From: Cristina
Date: Thu, 8 Jul 2021 10:57:13 -0500
Subject: Remove functionality associated with unused historical stats tables
 (#9721)

Fixes #9602
---
 changelog.d/9721.removal                       |   1 +
 docs/room_and_user_statistics.md               |  50 +----
 docs/sample_config.yaml                        |   5 -
 synapse/config/stats.py                        |   9 -
 synapse/handlers/stats.py                      |  27 ---
 synapse/storage/databases/main/purge_events.py |   1 -
 synapse/storage/databases/main/stats.py        | 291 +------------------------
 synapse/storage/schema/__init__.py             |   6 +-
 tests/handlers/test_stats.py                   | 203 +----------------
 tests/rest/admin/test_room.py                  |   1 -
 10 files changed, 22 insertions(+), 572 deletions(-)
 create mode 100644 changelog.d/9721.removal

(limited to 'tests')

diff --git a/changelog.d/9721.removal b/changelog.d/9721.removal
new file mode 100644
index 0000000000..da2ba48c84
--- /dev/null
+++ b/changelog.d/9721.removal
@@ -0,0 +1 @@
+Remove functionality associated with the unused `room_stats_historical` and `user_stats_historical` tables. Contributed by @xmunoz.
diff --git a/docs/room_and_user_statistics.md b/docs/room_and_user_statistics.md
index e1facb38d4..cc38c890bb 100644
--- a/docs/room_and_user_statistics.md
+++ b/docs/room_and_user_statistics.md
@@ -1,9 +1,9 @@
 Room and User Statistics
 ========================
 
-Synapse maintains room and user statistics (as well as a cache of room state),
-in various tables. These can be used for administrative purposes but are also
-used when generating the public room directory.
+Synapse maintains room and user statistics in various tables. These can be used
+for administrative purposes but are also used when generating the public room
+directory.
 
 # Synapse Developer Documentation
 
@@ -15,48 +15,8 @@ used when generating the public room directory.
 * **subject**: Something we are tracking stats about – currently a room or user.
 * **current row**: An entry for a subject in the appropriate current statistics
   table. Each subject can have only one.
-* **historical row**: An entry for a subject in the appropriate historical
-  statistics table. Each subject can have any number of these.
 
 ### Overview
 
-Stats are maintained as time series. There are two kinds of column:
-
-* absolute columns – where the value is correct for the time given by `end_ts`
-  in the stats row. (Imagine a line graph for these values)
-  * They can also be thought of as 'gauges' in Prometheus, if you are familiar.
-* per-slice columns – where the value corresponds to how many of the occurrences
-  occurred within the time slice given by `(end_ts − bucket_size)…end_ts`
-  or `start_ts…end_ts`. (Imagine a histogram for these values)
-
-Stats are maintained in two tables (for each type): current and historical.
-
-Current stats correspond to the present values. Each subject can only have one
-entry.
-
-Historical stats correspond to values in the past. Subjects may have multiple
-entries.
-
-## Concepts around the management of stats
-
-### Current rows
-
-Current rows contain the most up-to-date statistics for a room.
-They only contain absolute columns
-
-### Historical rows
-
-Historical rows can always be considered to be valid for the time slice and
-end time specified.
-
-* historical rows will not exist for every time slice – they will be omitted
-  if there were no changes. In this case, the following assumptions can be
-  made to interpolate/recreate missing rows:
-  - absolute fields have the same values as in the preceding row
-  - per-slice fields are zero (`0`)
-* historical rows will not be retained forever – rows older than a configurable
-  time will be purged.
-
-#### Purge
-
-The purging of historical rows is not yet implemented.
+Stats correspond to the present values. Current rows contain the most up-to-date
+statistics for a room. Each subject can only have one entry.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 71463168e3..cbbe7d58d9 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2652,11 +2652,6 @@ stats:
   #
   #enabled: false
 
-  # The size of each timeslice in the room_stats_historical and
-  # user_stats_historical tables, as a time period. Defaults to "1d".
-  #
-  #bucket_size: 1h
-

 # Server Notices room configuration
 #
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 78f61fe9da..6f253e00c0 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -38,13 +38,9 @@ class StatsConfig(Config):
 
     def read_config(self, config, **kwargs):
         self.stats_enabled = True
-        self.stats_bucket_size = 86400 * 1000
         stats_config = config.get("stats", None)
         if stats_config:
             self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
-            self.stats_bucket_size = self.parse_duration(
-                stats_config.get("bucket_size", "1d")
-            )
         if not self.stats_enabled:
             logger.warning(ROOM_STATS_DISABLED_WARN)
 
@@ -59,9 +55,4 @@ class StatsConfig(Config):
         # correctly.
         #
         #enabled: false
-
-        # The size of each timeslice in the room_stats_historical and
-        # user_stats_historical tables, as a time period. Defaults to "1d".
-        #
-        #bucket_size: 1h
         """
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 4e45d1da57..814d08efcb 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -45,7 +45,6 @@ class StatsHandler:
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
-        self.stats_bucket_size = hs.config.stats_bucket_size
 
         self.stats_enabled = hs.config.stats_enabled
 
@@ -106,20 +105,6 @@ class StatsHandler:
         room_deltas = {}
         user_deltas = {}
 
-        # Then count deltas for total_events and total_event_bytes.
-        (
-            room_count,
-            user_count,
-        ) = await self.store.get_changes_room_total_events_and_bytes(
-            self.pos, max_pos
-        )
-
-        for room_id, fields in room_count.items():
-            room_deltas.setdefault(room_id, Counter()).update(fields)
-
-        for user_id, fields in user_count.items():
-            user_deltas.setdefault(user_id, Counter()).update(fields)
-
         logger.debug("room_deltas: %s", room_deltas)
         logger.debug("user_deltas: %s", user_deltas)
 
@@ -181,12 +166,10 @@ class StatsHandler:
 
             event_content = {}  # type: JsonDict
 
-            sender = None
             if event_id is not None:
                 event = await self.store.get_event(event_id, allow_none=True)
                 if event:
                     event_content = event.content or {}
-                    sender = event.sender
 
             # All the values in this dict are deltas (RELATIVE changes)
             room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
@@ -244,12 +227,6 @@ class StatsHandler:
                     room_stats_delta["joined_members"] += 1
                 elif membership == Membership.INVITE:
                     room_stats_delta["invited_members"] += 1
-
-                    if sender and self.is_mine_id(sender):
-                        user_to_stats_deltas.setdefault(sender, Counter())[
-                            "invites_sent"
-                        ] += 1
-
                 elif membership == Membership.LEAVE:
                     room_stats_delta["left_members"] += 1
                 elif membership == Membership.BAN:
@@ -279,10 +256,4 @@ class StatsHandler:
                 room_state["is_federatable"] = (
                     event_content.get("m.federate", True) is True
                 )
-                if sender and self.is_mine_id(sender):
-                    user_to_stats_deltas.setdefault(sender, Counter())[
-                        "rooms_created"
-                    ] += 1
             elif typ == EventTypes.JoinRules:
                 room_state["join_rules"] = event_content.get("join_rule")
             elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 7fb7780d0f..ec6b1eb5d4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -392,7 +392,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "room_memberships",
             "room_stats_state",
             "room_stats_current",
-            "room_stats_historical",
             "room_stats_earliest_token",
             "rooms",
             "stream_ordering_to_exterm",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 82a1833509..b10bee6daf 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -26,7 +26,6 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import StoreError
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
-from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
@@ -49,14 +48,6 @@ ABSOLUTE_STATS_FIELDS = {
     "user": ("joined_rooms",),
 }
 
-# these fields are per-timeslice and so should be reset to 0 upon a new slice
-# You can draw these stats on a histogram.
-# Example: number of events sent locally during a time slice
-PER_SLICE_FIELDS = {
    "room": ("total_events", "total_event_bytes"),
-    "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"),
-}
-
 TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
 
 # these are the tables (& ID columns) which contain our actual subjects
@@ -106,7 +97,6 @@ class StatsStore(StateDeltasStore):
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
         self.stats_enabled = hs.config.stats_enabled
-        self.stats_bucket_size = hs.config.stats_bucket_size
 
         self.stats_delta_processing_lock = DeferredLock()
 
@@ -122,22 +112,6 @@ class StatsStore(StateDeltasStore):
         self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
         self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
 
-    def quantise_stats_time(self, ts):
-        """
-        Quantises a timestamp to be a multiple of the bucket size.
-
-        Args:
-            ts (int): the timestamp to quantise, in milliseconds since the Unix
-                Epoch
-
-        Returns:
-            int: a timestamp which
-              - is divisible by the bucket size;
-              - is no later than `ts`; and
-              - is the largest such timestamp.
-        """
-        return (ts // self.stats_bucket_size) * self.stats_bucket_size
-
     async def _populate_stats_process_users(self, progress, batch_size):
         """
         This is a background update which regenerates statistics for users.
@@ -288,56 +262,6 @@ class StatsStore(StateDeltasStore):
             desc="update_room_state",
         )
 
-    async def get_statistics_for_subject(
-        self, stats_type: str, stats_id: str, start: str, size: int = 100
-    ) -> List[dict]:
-        """
-        Get statistics for a given subject.
-
-        Args:
-            stats_type: The type of subject
-            stats_id: The ID of the subject (e.g. room_id or user_id)
-            start: Pagination start. Number of entries, not timestamp.
-            size: How many entries to return.
-
-        Returns:
-            A list of dicts, where the dict has the keys of
-            ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
-        """
-        return await self.db_pool.runInteraction(
-            "get_statistics_for_subject",
-            self._get_statistics_for_subject_txn,
-            stats_type,
-            stats_id,
-            start,
-            size,
-        )
-
-    def _get_statistics_for_subject_txn(
-        self, txn, stats_type, stats_id, start, size=100
-    ):
-        """
-        Transaction-bound version of L{get_statistics_for_subject}.
- """ - - table, id_col = TYPE_TO_TABLE[stats_type] - selected_columns = list( - ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] - ) - - slice_list = self.db_pool.simple_select_list_paginate_txn( - txn, - table + "_historical", - "end_ts", - start, - size, - retcols=selected_columns + ["bucket_size", "end_ts"], - keyvalues={id_col: stats_id}, - order_direction="DESC", - ) - - return slice_list - @cached() async def get_earliest_token_for_stats( self, stats_type: str, id: str @@ -451,14 +375,10 @@ class StatsStore(StateDeltasStore): table, id_col = TYPE_TO_TABLE[stats_type] - quantised_ts = self.quantise_stats_time(int(ts)) - end_ts = quantised_ts + self.stats_bucket_size - # Lets be paranoid and check that all the given field names are known abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type] - slice_field_names = PER_SLICE_FIELDS[stats_type] for field in chain(fields.keys(), absolute_field_overrides.keys()): - if field not in abs_field_names and field not in slice_field_names: + if field not in abs_field_names: # guard against potential SQL injection dodginess raise ValueError( "%s is not a recognised field" @@ -491,20 +411,6 @@ class StatsStore(StateDeltasStore): additive_relatives=deltas_of_absolute_fields, ) - per_slice_additive_relatives = { - key: fields.get(key, 0) for key in slice_field_names - } - self._upsert_copy_from_table_with_additive_relatives_txn( - txn=txn, - into_table=table + "_historical", - keyvalues={id_col: stats_id}, - extra_dst_insvalues={"bucket_size": self.stats_bucket_size}, - extra_dst_keyvalues={"end_ts": end_ts}, - additive_relatives=per_slice_additive_relatives, - src_table=table + "_current", - copy_columns=abs_field_names, - ) - def _upsert_with_additive_relatives_txn( self, txn, table, keyvalues, absolutes, additive_relatives ): @@ -572,201 +478,6 @@ class StatsStore(StateDeltasStore): current_row.update(absolutes) self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) - def _upsert_copy_from_table_with_additive_relatives_txn( - self, - txn, - into_table, - keyvalues, - extra_dst_keyvalues, - extra_dst_insvalues, - additive_relatives, - src_table, - copy_columns, - ): - """Updates the historic stats table with latest updates. - - This involves copying "absolute" fields from the `_current` table, and - adding relative fields to any existing values. - - Args: - txn: Transaction - into_table (str): The destination table to UPSERT the row into - keyvalues (dict[str, any]): Row-identifying key values - extra_dst_keyvalues (dict[str, any]): Additional keyvalues - for `into_table`. - extra_dst_insvalues (dict[str, any]): Additional values to insert - on new row creation for `into_table`. - additive_relatives (dict[str, any]): Fields that will be added onto - if existing row present. (Must be disjoint from copy_columns.) - src_table (str): The source table to copy from - copy_columns (iterable[str]): The list of columns to copy - """ - if self.database_engine.can_native_upsert: - ins_columns = chain( - keyvalues, - copy_columns, - additive_relatives, - extra_dst_keyvalues, - extra_dst_insvalues, - ) - sel_exprs = chain( - keyvalues, - copy_columns, - ( - "?" - for _ in chain( - additive_relatives, extra_dst_keyvalues, extra_dst_insvalues - ) - ), - ) - keyvalues_where = ("%s = ?" 
% f for f in keyvalues) - - sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns) - sets_ar = ( - "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f) - for f in additive_relatives - ) - - sql = """ - INSERT INTO %(into_table)s (%(ins_columns)s) - SELECT %(sel_exprs)s - FROM %(src_table)s - WHERE %(keyvalues_where)s - ON CONFLICT (%(keyvalues)s) - DO UPDATE SET %(sets)s - """ % { - "into_table": into_table, - "ins_columns": ", ".join(ins_columns), - "sel_exprs": ", ".join(sel_exprs), - "keyvalues_where": " AND ".join(keyvalues_where), - "src_table": src_table, - "keyvalues": ", ".join( - chain(keyvalues.keys(), extra_dst_keyvalues.keys()) - ), - "sets": ", ".join(chain(sets_cc, sets_ar)), - } - - qargs = list( - chain( - additive_relatives.values(), - extra_dst_keyvalues.values(), - extra_dst_insvalues.values(), - keyvalues.values(), - ) - ) - txn.execute(sql, qargs) - else: - self.database_engine.lock_table(txn, into_table) - src_row = self.db_pool.simple_select_one_txn( - txn, src_table, keyvalues, copy_columns - ) - all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self.db_pool.simple_select_one_txn( - txn, - into_table, - keyvalues=all_dest_keyvalues, - retcols=list(chain(additive_relatives.keys(), copy_columns)), - allow_none=True, - ) - - if dest_current_row is None: - merged_dict = { - **keyvalues, - **extra_dst_keyvalues, - **extra_dst_insvalues, - **src_row, - **additive_relatives, - } - self.db_pool.simple_insert_txn(txn, into_table, merged_dict) - else: - for (key, val) in additive_relatives.items(): - src_row[key] = dest_current_row[key] + val - self.db_pool.simple_update_txn( - txn, into_table, all_dest_keyvalues, src_row - ) - - async def get_changes_room_total_events_and_bytes( - self, min_pos: int, max_pos: int - ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: - """Fetches the counts of events in the given range of stream IDs. - - Args: - min_pos - max_pos - - Returns: - Mapping of room ID to field changes. - """ - - return await self.db_pool.runInteraction( - "stats_incremental_total_events_and_bytes", - self.get_changes_room_total_events_and_bytes_txn, - min_pos, - max_pos, - ) - - def get_changes_room_total_events_and_bytes_txn( - self, txn, low_pos: int, high_pos: int - ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: - """Gets the total_events and total_event_bytes counts for rooms and - senders, in a range of stream_orderings (including backfilled events). - - Args: - txn - low_pos: Low stream ordering - high_pos: High stream ordering - - Returns: - The room and user deltas for total_events/total_event_bytes in the - format of `stats_id` -> fields - """ - - if low_pos >= high_pos: - # nothing to do here. - return {}, {} - - if isinstance(self.database_engine, PostgresEngine): - new_bytes_expression = "OCTET_LENGTH(json)" - else: - new_bytes_expression = "LENGTH(CAST(json AS BLOB))" - - sql = """ - SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) 
- GROUP BY events.room_id - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - room_deltas = { - room_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for room_id, new_events, new_bytes in txn - } - - sql = """ - SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes - FROM events INNER JOIN event_json USING (event_id) - WHERE (? < stream_ordering AND stream_ordering <= ?) - OR (? <= stream_ordering AND stream_ordering <= ?) - GROUP BY events.sender - """ % ( - new_bytes_expression, - ) - - txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos)) - - user_deltas = { - user_id: {"total_events": new_events, "total_event_bytes": new_bytes} - for user_id, new_events, new_bytes in txn - if self.hs.is_mine_id(user_id) - } - - return room_deltas, user_deltas - async def _calculate_and_set_initial_state_for_room( self, room_id: str ) -> Tuple[dict, dict, int]: diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 0a53b73ccc..36340a652a 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 60 +SCHEMA_VERSION = 61 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -21,6 +21,10 @@ older versions of Synapse). See `README.md `_ for more information on how this works. + +Changes in SCHEMA_VERSION = 61: + - The `user_stats_historical` and `room_stats_historical` tables are not written and + are not read (previously, they were written but not read). """ diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index c9d4fd9336..e4059acda3 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -88,16 +88,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): def _get_current_stats(self, stats_type, stat_id): table, id_col = stats.TYPE_TO_TABLE[stats_type] - cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( - stats.PER_SLICE_FIELDS[stats_type] - ) - - end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) + cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) return self.get_success( self.store.db_pool.simple_select_one( - table + "_historical", - {id_col: stat_id, end_ts: end_ts}, + table + "_current", + {id_col: stat_id}, cols, allow_none=True, ) @@ -156,115 +152,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 1) self.assertEqual(r[0]["topic"], "foo") - def test_initial_earliest_token(self): - """ - Ingestion via notify_new_event will ignore tokens that the background - update have already processed. - """ - - self.reactor.advance(86401) - - self.hs.config.stats_enabled = False - self.handler.stats_enabled = False - - u1 = self.register_user("u1", "pass") - u1_token = self.login("u1", "pass") - - u2 = self.register_user("u2", "pass") - u2_token = self.login("u2", "pass") - - u3 = self.register_user("u3", "pass") - u3_token = self.login("u3", "pass") - - room_1 = self.helper.create_room_as(u1, tok=u1_token) - self.helper.send_state( - room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token - ) - - # Begin the ingestion by creating the temp tables. This will also store - # the position that the deltas should begin at, once they take over. 
- self.hs.config.stats_enabled = True - self.handler.stats_enabled = True - self.store.db_pool.updates._all_done = False - self.get_success( - self.store.db_pool.simple_update_one( - table="stats_incremental_position", - keyvalues={}, - updatevalues={"stream_id": 0}, - ) - ) - - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "populate_stats_prepare", "progress_json": "{}"}, - ) - ) - - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - # Now, before the table is actually ingested, add some more events. - self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) - self.helper.join(room=room_1, user=u2, tok=u2_token) - - # orig_delta_processor = self.store. - - # Now do the initial ingestion. - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_stats_cleanup", - "progress_json": "{}", - "depends_on": "populate_stats_process_rooms", - }, - ) - ) - - self.store.db_pool.updates._all_done = False - while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() - ): - self.get_success( - self.store.db_pool.updates.do_next_background_update(100), by=0.1 - ) - - self.reactor.advance(86401) - - # Now add some more events, triggering ingestion. Because of the stream - # position being set to before the events sent in the middle, a simpler - # implementation would reprocess those events, and say there were four - # users, not three. - self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) - self.helper.join(room=room_1, user=u3, tok=u3_token) - - # self.handler.notify_new_event() - - # We need to let the delta processor advance… - self.reactor.advance(10 * 60) - - # Get the slices! There should be two -- day 1, and day 2. - r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) - - self.assertEqual(len(r), 2) - - # The oldest has 2 joined members - self.assertEqual(r[-1]["joined_members"], 2) - - # The newest has 3 - self.assertEqual(r[0]["joined_members"], 3) - def test_create_user(self): """ When we create a user, it should have statistics already ready. @@ -296,22 +183,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNotNone(r1stats) self.assertIsNotNone(r2stats) - # contains the default things you'd expect in a fresh room - self.assertEqual( - r1stats["total_events"], - EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM, - "Wrong number of total_events in new room's stats!" - " You may need to update this if more state events are added to" - " the room creation process.", - ) - self.assertEqual( - r2stats["total_events"], - EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM, - "Wrong number of total_events in new room's stats!" - " You may need to update this if more state events are added to" - " the room creation process.", - ) - self.assertEqual( r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM ) @@ -327,24 +198,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(r2stats["invited_members"], 0) self.assertEqual(r2stats["banned_members"], 0) - def test_send_message_increments_total_events(self): - """ - When we send a message, it increments total_events. 
- """ - - self._perform_background_initial_update() - - u1 = self.register_user("u1", "pass") - u1token = self.login("u1", "pass") - r1 = self.helper.create_room_as(u1, tok=u1token) - r1stats_ante = self._get_current_stats("room", r1) - - self.helper.send(r1, "hiss", tok=u1token) - - r1stats_post = self._get_current_stats("room", r1) - - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) - def test_updating_profile_information_does_not_increase_joined_members_count(self): """ Check that the joined_members count does not increase when a user changes their @@ -378,7 +231,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_send_state_event_nonoverwriting(self): """ - When we send a non-overwriting state event, it increments total_events AND current_state_events + When we send a non-overwriting state event, it increments current_state_events """ self._perform_background_initial_update() @@ -399,44 +252,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, ) - def test_send_state_event_overwriting(self): - """ - When we send an overwriting state event, it increments total_events ONLY - """ - - self._perform_background_initial_update() - - u1 = self.register_user("u1", "pass") - u1token = self.login("u1", "pass") - r1 = self.helper.create_room_as(u1, tok=u1token) - - self.helper.send_state( - r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby" - ) - - r1stats_ante = self._get_current_stats("room", r1) - - self.helper.send_state( - r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby" - ) - - r1stats_post = self._get_current_stats("room", r1) - - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) - self.assertEqual( - r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], - 0, - ) - def test_join_first_time(self): """ - When a user joins a room for the first time, total_events, current_state_events and + When a user joins a room for the first time, current_state_events and joined_members should increase by exactly 1. """ @@ -455,7 +278,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, @@ -466,7 +288,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_join_after_leave(self): """ - When a user joins a room after being previously left, total_events and + When a user joins a room after being previously left, joined_members should increase by exactly 1. current_state_events should not increase. left_members should decrease by exactly 1. 
@@ -490,7 +312,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -504,7 +325,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_invited(self): """ - When a user invites another user, current_state_events, total_events and + When a user invites another user, current_state_events and invited_members should increase by exactly 1. """ @@ -522,7 +343,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 1, @@ -533,7 +353,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_join_after_invite(self): """ - When a user joins a room after being invited, total_events and + When a user joins a room after being invited and joined_members should increase by exactly 1. current_state_events should not increase. invited_members should decrease by exactly 1. @@ -556,7 +376,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -570,7 +389,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_left(self): """ - When a user leaves a room after joining, total_events and + When a user leaves a room after joining and left_members should increase by exactly 1. current_state_events should not increase. joined_members should decrease by exactly 1. @@ -593,7 +412,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, @@ -607,7 +425,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): def test_banned(self): """ - When a user is banned from a room after joining, total_events and + When a user is banned from a room after joining and left_members should increase by exactly 1. current_state_events should not increase. banned_members should decrease by exactly 1. @@ -630,7 +448,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): r1stats_post = self._get_current_stats("room", r1) - self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1) self.assertEqual( r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], 0, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ee071c2477..959d3cea77 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1753,7 +1753,6 @@ PURGE_TABLES = [ "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", -- cgit 1.5.1 From 19d0401c56a8f31441c65e62ffd688f615536d76 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 12 Jul 2021 11:21:04 -0400 Subject: Additional unit tests for spaces summary. 
(#10305) --- changelog.d/10305.misc | 1 + tests/handlers/test_space_summary.py | 204 ++++++++++++++++++++++++++++++++++- 2 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10305.misc (limited to 'tests') diff --git a/changelog.d/10305.misc b/changelog.d/10305.misc new file mode 100644 index 0000000000..8488d47f6f --- /dev/null +++ b/changelog.d/10305.misc @@ -0,0 +1 @@ +Additional unit tests for the spaces summary API. diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 9771d3fb3b..faed1f1a18 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,7 +14,7 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock -from synapse.api.constants import EventContentFields, RoomTypes +from synapse.api.constants import EventContentFields, JoinRules, RoomTypes from synapse.api.errors import AuthError from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin @@ -178,3 +178,205 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, [self.space]) self._assert_events(result, [(self.space, self.room)]) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. + subspace = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + subroom = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, subspace, token=self.token) + self._add_child(subspace, subroom, token=self.token) + # Also add the two rooms from the space into this subspace (causing loops). + self._add_child(subspace, self.room, token=self.token) + self._add_child(subspace, room2, self.token) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + + # The result should include each room a single time and each link. + self._assert_rooms(result, [self.space, self.room, subspace, subroom]) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, room2), + (self.space, subspace), + (subspace, subroom), + (subspace, self.room), + (subspace, room2), + ], + ) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + # Return some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. 
+ rooms = [ + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + { + "room_id": subroom, + "world_readable": True, + }, + ] + event_content = {"via": [fed_hostname]} + events = [ + { + "room_id": subspace, + "state_key": subroom, + "content": event_content, + }, + ] + return rooms, events + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms(result, [self.space, self.room, subspace, subroom]) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, subspace), + (subspace, subroom), + ], + ) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + # Note that these entries are brief, but should contain enough info. + rooms = [ + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [], + }, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [self.room], + }, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ] + + # Place each room in the sub-space. + event_content = {"via": [fed_hostname]} + events = [ + { + "room_id": subspace, + "state_key": room["room_id"], + "content": event_content, + } + for room in rooms + ] + + # Also include the subspace. + rooms.insert( + 0, + { + "room_id": subspace, + "world_readable": True, + }, + ) + return rooms, events + + # Add a room to the space which is on another server. 
+ self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms( + result, + [ + self.space, + self.room, + subspace, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, subspace), + (subspace, restricted_room), + (subspace, restricted_accessible_room), + (subspace, world_readable_room), + (subspace, joined_room), + ], + ) -- cgit 1.5.1 From 89cfc3dd9849b0580146151098ad039a7680c63f Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 12:43:15 +0200 Subject: [pyupgrade] `tests/` (#10347) --- changelog.d/10347.misc | 1 + tests/config/test_load.py | 4 ++-- tests/handlers/test_profile.py | 2 +- .../http/federation/test_matrix_federation_agent.py | 2 +- tests/http/test_fedclient.py | 8 +++----- tests/replication/_base.py | 6 +++--- tests/replication/test_multi_media_repo.py | 4 ++-- tests/replication/test_sharded_event_persister.py | 6 +++--- tests/rest/admin/test_admin.py | 6 ++---- tests/rest/admin/test_room.py | 20 ++++++++++---------- tests/rest/client/v1/test_rooms.py | 14 +++++++------- tests/rest/client/v2_alpha/test_relations.py | 2 +- tests/rest/client/v2_alpha/test_report_event.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/storage/test_directory.py | 2 +- tests/storage/test_profile.py | 12 ++---------- tests/storage/test_purge.py | 2 +- tests/storage/test_room.py | 2 +- tests/test_types.py | 4 +--- tests/unittest.py | 2 +- 20 files changed, 45 insertions(+), 58 deletions(-) create mode 100644 changelog.d/10347.misc (limited to 'tests') diff --git a/changelog.d/10347.misc b/changelog.d/10347.misc new file mode 100644 index 0000000000..b2275a1350 --- /dev/null +++ b/changelog.d/10347.misc @@ -0,0 +1 @@ +Run `pyupgrade` on the codebase. 
\ No newline at end of file diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ebe2c05165..903c69127d 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) @@ -120,7 +120,7 @@ class ConfigLoadingTestCase(unittest.TestCase): def generate_config_and_remove_lines_containing(self, needle): self.generate_config() - with open(self.file, "r") as f: + with open(self.file) as f: contents = f.readlines() contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index cdb41101b3..2928c4f48c 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -103,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - (self.get_success(self.store.get_profile_displayname(self.frank.localpart))) + self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) def test_set_my_name_if_disabled(self): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index e45980316b..a37bce08c3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -273,7 +273,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode("ascii")) + request.write(b'{ "a": 1 }') request.finish() self.reactor.pump((0.1,)) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index ed9a884d76..d9a8b077d3 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -102,7 +102,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(test_d) # Send it the HTTP response - res_json = '{ "a": 1 }'.encode("ascii") + res_json = b'{ "a": 1 }' protocol.dataReceived( b"HTTP/1.1 200 OK\r\n" b"Server: Fake\r\n" @@ -339,10 +339,8 @@ class FederationClientTests(HomeserverTestCase): # Send it the HTTP response client.dataReceived( - ( - b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n" - ) + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" ) # Push by enough to time it out diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 624bd1b927..386ea70a25 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -550,12 +550,12 @@ class FakeRedisPubSubProtocol(Protocol): if obj is None: return "$-1\r\n" if isinstance(obj, str): - return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + return f"${len(obj)}\r\n{obj}\r\n" if isinstance(obj, int): - return ":{val}\r\n".format(val=obj) + return f":{obj}\r\n" if isinstance(obj, (list, tuple)): items = "".join(self.encode(a) for a in obj) - return "*{len}\r\n{items}".format(len=len(obj), items=items) + return f"*{len(obj)}\r\n{items}" raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 76e6644353..b42f1288eb 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -70,7 
+70,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, FakeSite(resource), "GET", - "/{}/{}".format(target, media_id), + f"/{target}/{media_id}", shorthand=False, access_token=self.access_token, await_result=False, @@ -113,7 +113,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5eca5c165d..f3615af97e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -211,7 +211,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) @@ -241,7 +241,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(vector_clock_token), + f"/sync?since={vector_clock_token}", access_token=access_token, ) @@ -266,7 +266,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.reactor, sync_hs_site, "GET", - "/sync?since={}".format(next_batch), + f"/sync?since={next_batch}", access_token=access_token, ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 2f7090e554..a7c6e595b9 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Create a new group channel = self.make_request( "POST", - "/create_group".encode("ascii"), + b"/create_group", access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request( - "GET", "/joined_groups".encode("ascii"), access_token=access_token - ) + channel = self.make_request("GET", b"/joined_groups", access_token=access_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 959d3cea77..17ec8bfd3b 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") def _assert_peek(self, room_id, expect_code): """Assert that the admin user can (or cannot) peek into the room.""" @@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") class RoomTestCase(unittest.HomeserverTestCase): @@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.public_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + self.url = 
f"/_synapse/admin/v1/join/{self.public_room_id}" def test_requester_is_no_admin(self): """ @@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Join user to room. - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e94566ffd7..3df070c936 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/join", content={"reason": reason}, access_token=self.second_tok, ) @@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) @@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/kick", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) @@ -1248,7 +1248,7 @@ class 
RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/ban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/unban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/invite", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 856aa8682f..2e2f94742e 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index 1ec6b05e5b..a76a6fef1e 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) resp = self.helper.send(self.room_id, tok=self.admin_user_tok) self.event_id = resp["event_id"] - self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" def test_reason_str_and_score_int(self): data = {"reason": "this makes me sad", "score": -100} diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 95e7075841..2d6b49692e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): correctly decode it as the UTF-8 string, and use filename* in the response. 
""" - filename = parse.quote("\u2603".encode("utf8")).encode("ascii") + filename = parse.quote("\u2603".encode()).encode("ascii") channel = self._req( b"inline; filename*=utf-8''" + filename + self.test_image.extension ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 41bef62ca8..43628ce44f 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -59,5 +59,5 @@ class DirectoryStoreTestCase(HomeserverTestCase): self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (self.get_success(self.store.get_association_from_room_alias(self.alias))) + self.get_success(self.store.get_association_from_room_alias(self.alias)) ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 8a446da848..a1ba99ff14 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -45,11 +45,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_displayname(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) ) def test_avatar_url(self): @@ -76,9 +72,5 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - ( - self.get_success( - self.store.get_profile_avatar_url(self.u_frank.localpart) - ) - ) + self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) ) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 54c5b470c7..e5574063f1 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -75,7 +75,7 @@ class PurgeTests(HomeserverTestCase): token = self.get_success( self.store.get_topological_token_for_event(last["event_id"]) ) - event = "t{}-{}".format(token.topological + 1, token.stream + 1) + event = f"t{token.topological + 1}-{token.stream + 1}" # Purge everything before this topological token f = self.get_failure( diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 70257bf210..31ce7f6252 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -49,7 +49,7 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room_unknown_room(self): - self.assertIsNone((self.get_success(self.store.get_room("!uknown:test")))) + self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) def test_get_room_with_stats(self): self.assertDictContainsSubset( diff --git a/tests/test_types.py b/tests/test_types.py index d7881021d3..0d0c00d97a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -103,6 +103,4 @@ class MapUsernameTestCase(unittest.TestCase): def testNonAscii(self): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") - self.assertEqual( - map_username_to_mxid_localpart("têst".encode("utf-8")), "t=c3=aast" - ) + self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") diff --git a/tests/unittest.py b/tests/unittest.py index 74db7c08f1..907b94b10a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -140,7 +140,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))("Assert error for '.{}':".format(key)) from e + raise (type(e))(f"Assert error for '.{key}':") from e def assert_dict(self, required, actual): """Does a partial assert of a dict. 
-- cgit 1.5.1 From 93729719b8451493e1df9930feb9f02f14ea5cef Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 12:52:58 +0200 Subject: Use inline type hints in `tests/` (#10350) This PR is tantamount to running: python3.8 -m com2ann -v 6 tests/ (com2ann requires python 3.8 to run) --- changelog.d/10350.misc | 1 + tests/events/test_presence_router.py | 6 +++--- tests/module_api/test_api.py | 16 ++++++++-------- tests/replication/_base.py | 12 ++++++------ tests/replication/tcp/streams/test_events.py | 14 +++++++------- tests/replication/tcp/streams/test_receipts.py | 4 ++-- tests/replication/tcp/streams/test_typing.py | 4 ++-- tests/replication/test_multi_media_repo.py | 2 +- tests/rest/client/test_third_party_rules.py | 4 ++-- tests/rest/client/v1/test_login.py | 14 ++++++-------- tests/server.py | 8 +++++--- tests/storage/test_background_update.py | 4 +--- tests/storage/test_id_generators.py | 6 +++--- tests/test_state.py | 2 +- tests/test_utils/html_parsers.py | 6 +++--- tests/unittest.py | 2 +- tests/util/caches/test_descriptors.py | 2 +- tests/util/test_itertools.py | 18 +++++++++--------- 18 files changed, 62 insertions(+), 63 deletions(-) create mode 100644 changelog.d/10350.misc (limited to 'tests') diff --git a/changelog.d/10350.misc b/changelog.d/10350.misc new file mode 100644 index 0000000000..eed2d8552a --- /dev/null +++ b/changelog.d/10350.misc @@ -0,0 +1 @@ +Convert internal type variable syntax to reflect wider ecosystem use. \ No newline at end of file diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 875b0d0a11..c4ad33194d 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -152,7 +152,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_one_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "boop") @@ -274,7 +274,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence_updates, _ = sync_presence(self, self.other_user_id) self.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] self.assertEqual(presence_update.user_id, self.other_user_id) self.assertEqual(presence_update.state, "online") self.assertEqual(presence_update.status_msg, "I'm online!") @@ -320,7 +320,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 2c68b9a13c..81d9e2f484 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -100,9 +100,9 @@ class ModuleApiTestCase(HomeserverTestCase): "content": content, "sender": user_id, } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.message") self.assertEqual(event.room_id, room_id) @@ -136,9 +136,9 @@ class 
ModuleApiTestCase(HomeserverTestCase): "sender": user_id, "state_key": "", } - event = self.get_success( + event: EventBase = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) - ) # type: EventBase + ) self.assertEqual(event.sender, user_id) self.assertEqual(event.type, "m.room.power_levels") self.assertEqual(event.room_id, room_id) @@ -281,7 +281,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) for call in calls: call_args = call[0] - federation_transaction = call_args[0] # type: Transaction + federation_transaction: Transaction = call_args[0] # Get the sent EDUs in this transaction edus = federation_transaction.get_dict()["edus"] @@ -390,7 +390,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -443,7 +443,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -454,7 +454,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update = presence_updates[0] # type: UserPresenceState + presence_update: UserPresenceState = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 386ea70a25..e9fd991718 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -53,9 +53,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() - self.server = server_factory.buildProtocol( + self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol( None - ) # type: ServerReplicationStreamProtocol + ) # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" @@ -195,7 +195,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): fetching updates for given stream. """ - path = request.path # type: bytes # type: ignore + path: bytes = request.path # type: ignore self.assertRegex( path, br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" @@ -212,7 +212,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): unlike `BaseStreamTestCase`. 
""" - servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]] + servlets: List[Callable[[HomeServer, JsonResource], None]] = [] def setUp(self): super().setUp() @@ -448,7 +448,7 @@ class TestReplicationDataHandler(ReplicationDataHandler): super().__init__(hs) # list of received (stream_name, token, row) tuples - self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] + self.received_rdata_rows: List[Tuple[str, int, Any]] = [] async def on_rdata(self, stream_name, instance_name, token, rows): await super().on_rdata(stream_name, instance_name, token, rows) @@ -484,7 +484,7 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" - transport = None # type: Optional[FakeTransport] + transport: Optional[FakeTransport] = None def __init__(self, server: FakeRedisPubSubServer): self._server = server diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f51fa0a79e..666008425a 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -135,9 +135,9 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) events = [ self._inject_state_event(sender=OTHER_USER) @@ -238,7 +238,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual(row.data.event_id, pl_event.event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for stream_name, _, row in received_rows: self.assertEqual("events", stream_name) self.assertIsInstance(row, EventsStreamRow) @@ -290,11 +290,11 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point = self.get_success( + fork_point: List[str] = self.get_success( self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) - ) # type: List[str] + ) - events = [] # type: List[EventBase] + events: List[EventBase] = [] for user in user_ids: events.extend( self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) @@ -355,7 +355,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual(row.data.event_id, pl_events[i].event_id) # the state rows are unsorted - state_rows = [] # type: List[EventsStreamCurrentStateRow] + state_rows: List[EventsStreamCurrentStateRow] = [] for _ in range(STATES_PER_USER + 1): stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 7f5d932f0b..38e292c1ab 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -43,7 +43,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) @@ -75,7 +75,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): 
self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow + row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0] self.assertEqual("!room2:blue", row.room_id) self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index ecd360c2d0..3ff5afc6e5 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) @@ -102,7 +102,7 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row: TypingStream.TypingStreamRow = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([USER_ID], row.user_ids) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index b42f1288eb..ffa425328f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -31,7 +31,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request logger = logging.getLogger(__name__) -test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory] +test_server_connection_factory: Optional[TestServerTLSConnectionFactory] = None class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e1fe72fc5d..c5e1c5458b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -233,11 +233,11 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "content": content, "sender": self.user_id, } - event = self.get_success( + event: EventBase = self.get_success( current_rules_module().module_api.create_and_send_event_into_room( event_dict ) - ) # type: EventBase + ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 605b952316..7eba69642a 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # stick the flows results in a dict by type - flow_results = {} # type: Dict[str, Any] + flow_results: Dict[str, Any] = {} for f in channel.json_body["flows"]: flow_type = f["type"] self.assertNotIn( @@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.close() # there should be a link for each href - returned_idps = [] # type: List[str] + returned_idps: List[str] = [] for link in p.links: path, query = link.split("?", 1) self.assertEqual(path, "pick_idp") @@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # ... 
and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") assert cookie_headers - cookies = {} # type: Dict[str, str] + cookies: Dict[str, str] = {} for h in cookie_headers: key, value = h.split(";")[0].split("=", maxsplit=1) cookies[key] = value @@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode( - payload, secret, self.jwt_algorithm - ) # type: Union[str, bytes] + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) if isinstance(result, bytes): return result.decode("ascii") return result @@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") if isinstance(result, bytes): return result.decode("ascii") return result @@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie - cookies = {} # type: Dict[str,str] + cookies: Dict[str, str] = {} channel.extract_cookies(cookies) self.assertIn("username_mapping_session", cookies) session_id = cookies["username_mapping_session"] diff --git a/tests/server.py b/tests/server.py index f32d8dc375..6fddd3b305 100644 --- a/tests/server.py +++ b/tests/server.py @@ -52,7 +52,7 @@ class FakeChannel: _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) _ip = attr.ib(type=str, default="127.0.0.1") - _producer = None # type: Optional[Union[IPullProducer, IPushProducer]] + _producer: Optional[Union[IPullProducer, IPushProducer]] = None @property def json_body(self): @@ -316,8 +316,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] - lookups = self.lookups = {} # type: Dict[str, str] - self._thread_callbacks = deque() # type: Deque[Callable[[], None]] + self.lookups: Dict[str, str] = {} + self._thread_callbacks: Deque[Callable[[], None]] = deque() + + lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 069db0edc4..0da42b5ac5 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -7,9 +7,7 @@ from tests import unittest class BackgroundUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.updates = ( - self.hs.get_datastore().db_pool.updates - ) # type: BackgroundUpdater + self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates # the base test class should have run the real bg updates for us self.assertTrue( self.get_success(self.updates.has_completed_background_updates()) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 792b1c44c1..7486078284 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -27,7 +27,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = 
self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -460,7 +460,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) @@ -586,7 +586,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.db_pool = self.store.db_pool # type: DatabasePool + self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) diff --git a/tests/test_state.py b/tests/test_state.py index 62f7095873..780eba823c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): self.store.register_events(graph.walk()) - context_store = {} # type: dict[str, EventContext] + context_store: dict[str, EventContext] = {} for event in graph.walk(): context = yield defer.ensureDeferred( diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index 1fbb38f4be..e878af5f12 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -23,13 +23,13 @@ class TestHtmlParser(HTMLParser): super().__init__() # a list of links found in the doc - self.links = [] # type: List[str] + self.links: List[str] = [] # the values of any hidden s: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] + self.hiddens: Dict[str, Optional[str]] = {} # the values of any radio buttons: map from name to list of values - self.radios = {} # type: Dict[str, List[Optional[str]]] + self.radios: Dict[str, List[Optional[str]]] = {} def handle_starttag( self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] diff --git a/tests/unittest.py b/tests/unittest.py index 907b94b10a..c6d9064423 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase): if not isinstance(deferred, Deferred): return d - results = [] # type: list + results: list = [] deferred.addBoth(results.append) self.pump(by=by) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 0277998cbe..39947a166b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -174,7 +174,7 @@ class DescriptorTestCase(unittest.TestCase): return self.result obj = Cls() - callbacks = set() # type: Set[str] + callbacks: Set[str] = set() # set off an asynchronous request obj.result = origin_d = defer.Deferred() diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index e712eb42ea..3c0ddd4f18 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -44,7 +44,7 @@ class ChunkSeqTests(TestCase): ) def test_empty_input(self): - parts = chunk_seq([], 5) # type: Iterable[Sequence] + parts: Iterable[Sequence] = chunk_seq([], 5) self.assertEqual( list(parts), @@ -56,13 +56,13 @@ class SortTopologically(TestCase): def test_empty(self): "Test that an empty graph works correctly" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} self.assertEqual(list(sorted_topologically([], graph)), []) def test_handle_empty_graph(self): "Test that a graph where a node doesn't have an 
entry is treated as empty" - graph = {} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -70,7 +70,7 @@ class SortTopologically(TestCase): def test_disconnected(self): "Test that a graph with no edges work" - graph = {1: [], 2: []} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: []} # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) @@ -78,19 +78,19 @@ class SortTopologically(TestCase): def test_linear(self): "Test that a simple `4 -> 3 -> 2 -> 1` graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_subset(self): "Test that only sorting a subset of the graph works" - graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) def test_fork(self): "Test that a forked graph works" - graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should # always get the same one. @@ -98,12 +98,12 @@ class SortTopologically(TestCase): def test_duplicates(self): "Test that a graph with duplicate edges work" - graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) def test_multiple_paths(self): "Test that a graph with multiple paths between two nodes work" - graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]] + graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) -- cgit 1.5.1 From 2d16e69b4bf09b5274a8fa15c8ca4719db8366c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 13 Jul 2021 08:59:27 -0400 Subject: Show all joinable rooms in the spaces summary. (#10298) Previously only world-readable rooms were shown. This means that rooms which are public, knockable, or invite-only with a pending invitation, are included in a space summary. It also applies the same logic to the experimental room version from MSC3083 -- if a user has access to the proper allowed rooms then it is shown in the spaces summary. This change is made per MSC3173 allowing stripped state of a room to be shown to any potential room joiner. 
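In rough terms, the accessibility rule this commit implements (see `_is_room_accessible` in the handler diff below) can be sketched as follows. This is a simplified illustration only: the helper name and flat boolean arguments are stand-ins chosen for brevity, not the handler's real signature.

```python
def room_is_accessible(
    join_rule: str,
    history_visibility: str,
    is_member: bool,
    has_invite: bool,
    in_allowed_space: bool,
) -> bool:
    # Public and knockable rooms may be shown to any potential joiner,
    # since anyone could join (or knock on) them.
    if join_rule in ("public", "knock"):
        return True
    # World-readable rooms were already included before this change.
    if history_visibility == "world_readable":
        return True
    # Otherwise fall back to requester-specific checks: joined or
    # invited (per MSC3173), or granted access via a space listed in
    # the room's MSC3083 restricted join rules.
    return is_member or has_invite or in_allowed_space
```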
--- changelog.d/10298.feature | 1 + changelog.d/10305.feature | 1 + changelog.d/10305.misc | 1 - synapse/handlers/space_summary.py | 68 +++++++--- synapse/storage/databases/main/roommember.py | 13 +- tests/handlers/test_space_summary.py | 191 ++++++++++++++++++++++++--- 6 files changed, 237 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10298.feature create mode 100644 changelog.d/10305.feature delete mode 100644 changelog.d/10305.misc (limited to 'tests') diff --git a/changelog.d/10298.feature b/changelog.d/10298.feature new file mode 100644 index 0000000000..7059db5075 --- /dev/null +++ b/changelog.d/10298.feature @@ -0,0 +1 @@ +The spaces summary API now returns any joinable rooms, not only rooms which are world-readable. diff --git a/changelog.d/10305.feature b/changelog.d/10305.feature new file mode 100644 index 0000000000..7059db5075 --- /dev/null +++ b/changelog.d/10305.feature @@ -0,0 +1 @@ +The spaces summary API now returns any joinable rooms, not only rooms which are world-readable. diff --git a/changelog.d/10305.misc b/changelog.d/10305.misc deleted file mode 100644 index 8488d47f6f..0000000000 --- a/changelog.d/10305.misc +++ /dev/null @@ -1 +0,0 @@ -Additional unit tests for the spaces summary API. diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index b585057ec3..366e6211e5 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -24,6 +24,7 @@ from synapse.api.constants import ( EventContentFields, EventTypes, HistoryVisibility, + JoinRules, Membership, RoomTypes, ) @@ -150,14 +151,21 @@ class SpaceSummaryHandler: # The room should only be included in the summary if: # a. the user is in the room; # b. the room is world readable; or - # c. the user is in a space that has been granted access to - # the room. + # c. the user could join the room, e.g. the join rules + # are set to public or the user is in a space that + # has been granted access to the room. # # Note that we know the user is not in the root room (which is # why the remote call was made in the first place), but the user # could be in one of the children rooms and we just didn't know # about the link. - include_room = room.get("world_readable") is True + + # The API doesn't return the room version so assume that a + # join rule of knock is valid. + include_room = ( + room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + or room.get("world_readable") is True + ) # Check if the user is a member of any of the allowed spaces # from the response. @@ -420,9 +428,8 @@ class SpaceSummaryHandler: It should be included if: - * The requester is joined or invited to the room. - * The requester can join without an invite (per MSC3083). - * The origin server has any user that is joined or invited to the room. + * The requester is joined or can join the room (per MSC3173). + * The origin server has any user that is joined or can join the room. * The history visibility is set to world readable. Args: @@ -441,13 +448,39 @@ class SpaceSummaryHandler: # If there's no state for the room, it isn't known. if not state_ids: + # The user might have a pending invite for the room. + if requester and await self._store.get_invite_for_local_user_in_room( + requester, room_id + ): + return True + logger.info("room %s is unknown, omitting from summary", room_id) return False room_version = await self._store.get_room_version(room_id) - # if we have an authenticated requesting user, first check if they are able to view - # stripped state in the room. 
+ # Include the room if it has join rules of public or knock. + join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) + if join_rules_event_id: + join_rules_event = await self._store.get_event(join_rules_event_id) + join_rule = join_rules_event.content.get("join_rule") + if join_rule == JoinRules.PUBLIC or ( + room_version.msc2403_knocking and join_rule == JoinRules.KNOCK + ): + return True + + # Include the room if it is peekable. + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + # Otherwise we need to check information specific to the user or server. + + # If we have an authenticated requesting user, check if they are a member + # of the room (or can join the room). if requester: member_event_id = state_ids.get((EventTypes.Member, requester), None) @@ -470,9 +503,11 @@ class SpaceSummaryHandler: return True # If this is a request over federation, check if the host is in the room or - # is in one of the spaces specified via the join rules. + # has a user who could join the room. elif origin: - if await self._event_auth_handler.check_host_in_room(room_id, origin): + if await self._event_auth_handler.check_host_in_room( + room_id, origin + ) or await self._store.is_host_invited(room_id, origin): return True # Alternately, if the host has a user in any of the spaces specified @@ -490,18 +525,10 @@ class SpaceSummaryHandler: ): return True - # otherwise, check if the room is peekable - hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, ""), None) - if hist_vis_event_id: - hist_vis_ev = await self._store.get_event(hist_vis_event_id) - hist_vis = hist_vis_ev.content.get("history_visibility") - if hist_vis == HistoryVisibility.WORLD_READABLE: - return True - logger.info( - "room %s is unpeekable and user %s is not a member / not allowed to join, omitting from summary", + "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", room_id, - requester, + requester or origin, ) return False @@ -535,6 +562,7 @@ class SpaceSummaryHandler: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], + "join_rules": stats["join_rules"], "world_readable": ( stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2796354a1f..4d82c4c26d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -703,13 +703,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: + return await self._check_host_room_membership(room_id, host, Membership.JOIN) + + @cached(max_entries=10000) + async def is_host_invited(self, room_id: str, host: str) -> bool: + return await self._check_host_room_membership(room_id, host, Membership.INVITE) + + async def _check_host_room_membership( + self, room_id: str, host: str, membership: str + ) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") sql = """ SELECT state_key FROM current_state_events AS c INNER JOIN room_memberships AS m USING (event_id) - WHERE m.membership = 'join' + WHERE m.membership = ? 
AND type = 'm.room.member' AND c.room_id = ? AND state_key LIKE ? @@ -722,7 +731,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): like_clause = "%:" + host rows = await self.db_pool.execute( - "is_host_joined", None, sql, room_id, like_clause + "is_host_joined", None, sql, membership, room_id, like_clause ) if not rows: diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index faed1f1a18..3f73ad7f94 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,8 +14,18 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock -from synapse.api.constants import EventContentFields, JoinRules, RoomTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin from synapse.rest.client.v1 import login, room @@ -117,7 +127,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): """Add a child room to a space.""" self.helper.send_state( space_id, - event_type="m.space.child", + event_type=EventTypes.SpaceChild, body={"via": [self.hs.hostname]}, tok=token, state_key=room_id, @@ -155,29 +165,129 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # The user cannot see the space. self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - # Joining the room causes it to be visible. - self.helper.join(self.space, user2, tok=token2) + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - - # The result should only have the space, but includes the link to the room. - self._assert_rooms(result, [self.space]) + self._assert_rooms(result, [self.space, self.room]) self._assert_events(result, [(self.space, self.room)]) - def test_world_readable(self): - """A world-readable room is visible to everyone.""" + # Make it not world-readable again and confirm it results in an error. self.helper.send_state( self.space, - event_type="m.room.history_visibility", - body={"history_visibility": "world_readable"}, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, tok=self.token, ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + + # Join the space and results should be returned. 
+ self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, [self.space, self.room]) + self._assert_events(result, [(self.space, self.room)]) + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, + tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, + ) + self._add_child(self.space, room_id, self.token) + return room_id + + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") - # The space should be visible, as well as the link to the room. + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.MSC3083_RESTRICTED, + room_version=RoomVersions.MSC3083.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.MSC3083_RESTRICTED, + room_version=RoomVersions.MSC3083.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space]) - self._assert_events(result, [(self.space, self.room)]) + + self._assert_rooms( + result, + [ + self.space, + self.room, + public_room, + knock_room, + invited_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, public_room), + (self.space, knock_room), + (self.space, not_invited_room), + (self.space, invited_room), + (self.space, restricted_room), + (self.space, restricted_accessible_room), + (self.space, world_readable_room), + (self.space, joined_room), + ], + ) def test_complex_space(self): """ @@ -186,7 +296,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Create an inaccessible room. 
user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") - room2 = self.helper.create_room_as(user2, tok=token2) + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) # This is a bit odd as "user" is adding a room they don't know about, but # it works for the tests. self._add_child(self.space, room2, self.token) @@ -292,16 +402,60 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): subspace = "#subspace:" + fed_hostname # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname restricted_room = "#restricted:" + fed_hostname restricted_accessible_room = "#restricted_accessible:" + fed_hostname world_readable_room = "#world_readable:" + fed_hostname joined_room = self.helper.create_room_as(self.user, tok=self.token) + # Poke an invite over federation into the database. + fed_handler = self.hs.get_federation_handler() + event = make_event_from_dict( + { + "room_id": invited_room, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": "@remote:" + fed_hostname, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): # Note that these entries are brief, but should contain enough info. rooms = [ + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, { "room_id": restricted_room, "world_readable": False, @@ -364,6 +518,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.space, self.room, subspace, + public_room, + knock_room, + invited_room, restricted_accessible_room, world_readable_room, joined_room, @@ -374,6 +531,10 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): [ (self.space, self.room), (self.space, subspace), + (subspace, public_room), + (subspace, knock_room), + (subspace, not_invited_room), + (subspace, invited_room), (subspace, restricted_room), (subspace, restricted_accessible_room), (subspace, world_readable_room), -- cgit 1.5.1 From eb3beb8f12a5ee93e19eacf0f03c6bcde18999fe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 14 Jul 2021 09:13:40 -0400 Subject: Add type hints and comments to event auth code. (#10393) --- changelog.d/10393.misc | 1 + mypy.ini | 1 + synapse/event_auth.py | 3 +++ tests/test_event_auth.py | 23 +++++++++++++---------- 4 files changed, 18 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10393.misc (limited to 'tests') diff --git a/changelog.d/10393.misc b/changelog.d/10393.misc new file mode 100644 index 0000000000..e80f16d607 --- /dev/null +++ b/changelog.d/10393.misc @@ -0,0 +1 @@ +Add type hints and comments to event auth code. 
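For context, a call into the auth checker whose arguments are documented below might look roughly like this. This is a hypothetical usage sketch: `event` and `auth_events` stand in for a real event and the room's current auth state, and the argument order follows the signature shown in the `synapse/event_auth.py` diff.

```python
from synapse import event_auth
from synapse.api.room_versions import RoomVersions

# Raises AuthError if the event is not allowed by the auth events;
# returns None otherwise. do_size_check defaults to True.
event_auth.check(
    RoomVersions.V6,     # the version of the room
    event,               # the EventBase being checked
    auth_events,         # map of (event_type, state_key) -> EventBase
    do_sig_check=False,  # skip verifying the sending server's
                         # signature, e.g. for locally-created events
)
```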
diff --git a/mypy.ini b/mypy.ini index 72ce932d73..8717ae738e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,6 +83,7 @@ files = synapse/util/stringutils.py, synapse/visibility.py, tests/replication, + tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, tests/rest/client/v1/test_login.py, diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 89bcf81515..a3df6cfcc1 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -48,6 +48,9 @@ def check( room_version_obj: the version of the room event: the event being checked. auth_events: the existing room state. + do_sig_check: True if it should be verified that the sending server + signed the event. + do_size_check: True if the size of the event fields should be verified. Raises: AuthError if the checks fail diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 88888319cc..f73306ecc4 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -13,12 +13,13 @@ # limitations under the License. import unittest +from typing import Optional from synapse import event_auth from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.types import get_domain_from_id +from synapse.events import EventBase, make_event_from_dict +from synapse.types import JsonDict, get_domain_from_id class EventAuthTestCase(unittest.TestCase): @@ -432,7 +433,7 @@ class EventAuthTestCase(unittest.TestCase): TEST_ROOM_ID = "!test:room" -def _create_event(user_id): +def _create_event(user_id: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -444,7 +445,9 @@ def _create_event(user_id): ) -def _member_event(user_id, membership, sender=None): +def _member_event( + user_id: str, membership: str, sender: Optional[str] = None +) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -458,11 +461,11 @@ def _member_event(user_id, membership, sender=None): ) -def _join_event(user_id): +def _join_event(user_id: str) -> EventBase: return _member_event(user_id, "join") -def _power_levels_event(sender, content): +def _power_levels_event(sender: str, content: JsonDict) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -475,7 +478,7 @@ def _power_levels_event(sender, content): ) -def _alias_event(sender, **kwargs): +def _alias_event(sender: str, **kwargs) -> EventBase: data = { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), @@ -488,7 +491,7 @@ def _alias_event(sender, **kwargs): return make_event_from_dict(data) -def _random_state_event(sender): +def _random_state_event(sender: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -501,7 +504,7 @@ def _random_state_event(sender): ) -def _join_rules_event(sender, join_rule): +def _join_rules_event(sender: str, join_rule: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -519,7 +522,7 @@ def _join_rules_event(sender, join_rule): event_count = 0 -def _get_event_id(): +def _get_event_id() -> str: global event_count c = event_count event_count += 1 -- cgit 1.5.1 From c7603af1d06d65932c420ae76002b6ed94dbf23c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 15 Jul 2021 11:37:08 +0200 Subject: Allow providing credentials to `http_proxy` (#10360) --- changelog.d/10360.feature | 1 + synapse/http/proxyagent.py | 12 +++++++- tests/http/test_proxyagent.py | 65 ++++++++++++++++++++++++++++++++++--------- 3 
files changed, 64 insertions(+), 14 deletions(-) create mode 100644 changelog.d/10360.feature (limited to 'tests') diff --git a/changelog.d/10360.feature b/changelog.d/10360.feature new file mode 100644 index 0000000000..904221cb6d --- /dev/null +++ b/changelog.d/10360.feature @@ -0,0 +1 @@ +Allow providing credentials to `http_proxy`. \ No newline at end of file diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7dfae8b786..7a6a1717de 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -117,7 +117,8 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - # Parse credentials from https proxy connection string if present + # Parse credentials from http and https proxy connection string if present + self.http_proxy_creds, http_proxy = parse_username_password(http_proxy) self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) self.http_proxy_endpoint = _http_proxy_endpoint( @@ -189,6 +190,15 @@ class ProxyAgent(_AgentBase): and self.http_proxy_endpoint and not should_skip_proxy ): + # Determine whether we need to set Proxy-Authorization headers + if self.http_proxy_creds: + # Set a Proxy-Authorization header + if headers is None: + headers = Headers() + headers.addRawHeader( + b"Proxy-Authorization", + self.http_proxy_creds.as_proxy_authorization_value(), + ) # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: pool_key = ("http-proxy", self.http_proxy_endpoint) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index fefc8099c9..437113929a 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -205,6 +205,41 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"}) def test_http_request_via_proxy(self): + """ + Tests that requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, + ) + def test_http_request_via_proxy_with_auth(self): + """ + Tests that authenticated requests can be made through a proxy. + """ + self._do_http_request_via_proxy(auth_credentials="bob:pinkponies") + + @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) + def test_https_request_via_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + + def _do_http_request_via_proxy( + self, + auth_credentials: Optional[str] = None, + ): + """ + Tests that requests can be made through a proxy. 
+ """ agent = ProxyAgent(self.reactor, use_proxy=True) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -229,6 +264,23 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(b"bob:pinkponies") + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + self.assertEqual(request.method, b"GET") self.assertEqual(request.path, b"http://test.com") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"]) @@ -241,19 +293,6 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") - @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) - def test_https_request_via_proxy(self): - """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials=None) - - @patch.dict( - os.environ, - {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, - ) - def test_https_request_via_proxy_with_auth(self): - """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") - def _do_https_request_via_proxy( self, auth_credentials: Optional[str] = None, -- cgit 1.5.1 From ac5c221208ceb499cf8e9305b03efe1765ba48f6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 15 Jul 2021 11:52:56 +0100 Subject: Stagger send presence to remotes (#10398) This is to help with performance, where trying to connect to thousands of hosts at once can consume a lot of CPU (due to TLS etc). Co-authored-by: Brendan Abolivier --- changelog.d/10398.misc | 1 + synapse/federation/sender/__init__.py | 96 +++++++++++++++++++++- synapse/federation/sender/per_destination_queue.py | 16 +++- tests/events/test_presence_router.py | 8 ++ 4 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10398.misc (limited to 'tests') diff --git a/changelog.d/10398.misc b/changelog.d/10398.misc new file mode 100644 index 0000000000..326e54655a --- /dev/null +++ b/changelog.d/10398.misc @@ -0,0 +1 @@ +Stagger sending of presence update to remote servers, reducing CPU spikes caused by starting many connections to remote servers at once. 
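The heart of the change below is an adaptive delay between waking up per-destination queues: the sleep shrinks as the queue grows, so the whole queue still drains within `_MAX_TIME_IN_QUEUE` seconds while short queues are woken almost immediately. A minimal standalone sketch of that loop follows, using asyncio and a plain callback in place of Synapse's Twisted-based clock and per-destination queues:

```python
import asyncio
from collections import OrderedDict
from typing import Callable

MAX_TIME_IN_QUEUE = 30.0  # drain every destination within 30 seconds
MAX_DELAY = 0.1  # but never wait more than 100ms between wake-ups


async def drain(queue: "OrderedDict[str, None]", wake: Callable[[str], None]) -> None:
    if not queue:
        return
    # Start with a delay small enough to process the current queue
    # within MAX_TIME_IN_QUEUE seconds, capped so that short queues
    # are handled gracefully.
    delay = min(MAX_DELAY, MAX_TIME_IN_QUEUE / len(queue))
    while queue:
        destination, _ = queue.popitem(last=False)  # FIFO order
        wake(destination)  # e.g. attempt a new transaction to this host
        await asyncio.sleep(delay)
        if not queue:
            break
        # More destinations may have been added in the meantime, so
        # shrink the delay to keep the MAX_TIME_IN_QUEUE deadline.
        delay = min(delay, MAX_TIME_IN_QUEUE / len(queue))
```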
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 0960f033bc..d980e0d986 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -14,9 +14,12 @@ import abc import logging +from collections import OrderedDict from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple +import attr from prometheus_client import Counter +from typing_extensions import Literal from twisted.internet import defer @@ -33,8 +36,12 @@ from synapse.metrics import ( event_processing_loop_room_count, events_processed_counter, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.util import Clock from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -137,6 +144,84 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() +@attr.s +class _PresenceQueue: + """A queue of destinations that need to be woken up due to new presence + updates. + + Staggers waking up of per destination queues to ensure that we don't attempt + to start TLS connections with many hosts all at once, leading to pinned CPU. + """ + + # The maximum duration in seconds between queuing up a destination and it + # being woken up. + _MAX_TIME_IN_QUEUE = 30.0 + + # The maximum duration in seconds between waking up consecutive destination + # queues. + _MAX_DELAY = 0.1 + + sender: "FederationSender" = attr.ib() + clock: Clock = attr.ib() + queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict) + processing: bool = attr.ib(default=False) + + def add_to_queue(self, destination: str) -> None: + """Add a destination to the queue to be woken up.""" + + self.queue[destination] = None + + if not self.processing: + self._handle() + + @wrap_as_background_process("_PresenceQueue.handle") + async def _handle(self) -> None: + """Background process to drain the queue.""" + + if not self.queue: + return + + assert not self.processing + self.processing = True + + try: + # We start with a delay that should drain the queue quickly enough that + # we process all destinations in the queue in _MAX_TIME_IN_QUEUE + # seconds. + # + # We also add an upper bound to the delay, to gracefully handle the + # case where the queue only has a few entries in it. + current_sleep_seconds = min( + self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue) + ) + + while self.queue: + destination, _ = self.queue.popitem(last=False) + + queue = self.sender._get_per_destination_queue(destination) + + if not queue._new_data_to_send: + # The per destination queue has already been woken up. + continue + + queue.attempt_new_transaction() + + await self.clock.sleep(current_sleep_seconds) + + if not self.queue: + break + + # More destinations may have been added to the queue, so we may + # need to reduce the delay to ensure everything gets processed + # within _MAX_TIME_IN_QUEUE seconds. 
+ current_sleep_seconds = min( + current_sleep_seconds, self._MAX_TIME_IN_QUEUE / len(self.queue) + ) + + finally: + self.processing = False + + class FederationSender(AbstractFederationSender): def __init__(self, hs: "HomeServer"): self.hs = hs @@ -208,6 +293,8 @@ class FederationSender(AbstractFederationSender): self._external_cache = hs.get_external_cache() + self._presence_queue = _PresenceQueue(self, self.clock) + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -517,7 +604,12 @@ class FederationSender(AbstractFederationSender): self._instance_name, destination ): continue - self._get_per_destination_queue(destination).send_presence(states) + + self._get_per_destination_queue(destination).send_presence( + states, start_loop=False + ) + + self._presence_queue.add_to_queue(destination) def build_and_send_edu( self, diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d06a3aff19..c11d1f6d31 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -171,14 +171,24 @@ class PerDestinationQueue: self.attempt_new_transaction() - def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if necessary. + def send_presence( + self, states: Iterable[UserPresenceState], start_loop: bool = True + ) -> None: + """Add presence updates to the queue. + + Args: + states: Presence updates to send + start_loop: Whether to start the transmission loop if not already + running. Args: states: presence to send """ self._pending_presence.update({state.user_id: state for state in states}) - self.attempt_new_transaction() + self._new_data_to_send = True + + if start_loop: + self.attempt_new_transaction() def queue_read_receipt(self, receipt: ReadReceipt) -> None: """Add a RR to the list to be sent. Doesn't start the transmission loop yet diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index c4ad33194d..3f41e99950 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -285,6 +285,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) self.assertEqual(len(presence_updates), 3) + # We stagger sending of presence, so we need to wait a bit for them to + # get sent out. + self.reactor.advance(60) + # Test that sending to a remote user works remote_user_id = "@far_away_person:island" @@ -301,6 +305,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): self.module_api.send_local_online_presence_to([remote_user_id]) ) + # We stagger sending of presence, so we need to wait a bit for them to + # get sent out. 
+ self.reactor.advance(60) + # Check that the expected presence updates were sent # We explicitly compare using sets as we expect that calling # module_api.send_local_online_presence_to will create a presence -- cgit 1.5.1 From 6a6006825067827b533b9c2b35c5a1d6a796e27c Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Thu, 15 Jul 2021 13:51:27 +0100 Subject: Add tests to characterise the current behaviour of R30 phone-home metrics (#10315) Signed-off-by: Olivier Wilkinson (reivilibre) --- changelog.d/10315.misc | 1 + tests/app/test_phone_stats_home.py | 153 +++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 changelog.d/10315.misc create mode 100644 tests/app/test_phone_stats_home.py (limited to 'tests') diff --git a/changelog.d/10315.misc b/changelog.d/10315.misc new file mode 100644 index 0000000000..2c78644e20 --- /dev/null +++ b/changelog.d/10315.misc @@ -0,0 +1 @@ +Add tests to characterise the current behaviour of R30 phone-home metrics. diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py new file mode 100644 index 0000000000..2da6ba4dde --- /dev/null +++ b/tests/app/test_phone_stats_home.py @@ -0,0 +1,153 @@ +import synapse +from synapse.rest.client.v1 import login, room + +from tests import unittest +from tests.unittest import HomeserverTestCase + +ONE_DAY_IN_SECONDS = 86400 + + +class PhoneHomeTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + # Override the retention time for the user_ips table because otherwise it + # gets pruned too aggressively for our R30 test. + @unittest.override_config({"user_ips_max_age": "365d"}) + def test_r30_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30 metric + to consider a user 'retained'. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the R30 results do not count that user. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Advance 30 days (+ 1 second, because strict inequality causes issues if we are + # bang on 30 days later). + self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait some time for _update_client_ips_batch to get + # called and update the user_ips table. + self.reactor.advance(2 * 60 * 60) + + # *Now* the user is counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance 29 days. The user has now not posted for 29 days. + self.reactor.advance(29 * ONE_DAY_IN_SECONDS) + + # The user is still counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance another day. The user has now not posted for 30 days. 
+ self.reactor.advance(ONE_DAY_IN_SECONDS) + + # The user is now no longer counted in R30. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + def test_r30_minimum_usage_using_default_config(self): + """ + Tests the minimum amount of interaction necessary for the R30 metric + to consider a user 'retained'. + + N.B. This test does not override the `user_ips_max_age` config setting, + which defaults to 28 days. + """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the R30 results do not count that user. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Advance 30 days (+ 1 second, because strict inequality causes issues if we are + # bang on 30 days later). + self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait some time for _update_client_ips_batch to get + # called and update the user_ips table. + self.reactor.advance(2 * 60 * 60) + + # *Now* the user is counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance 27 days. The user has now not posted for 27 days. + self.reactor.advance(27 * ONE_DAY_IN_SECONDS) + + # The user is still counted. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + # Advance another day. The user has now not posted for 28 days. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # The user is now no longer counted in R30. + # (This is because the user_ips table has been pruned, which by default + # only preserves the last 28 days of entries.) + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + def test_r30_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30 statistic, even if they post every day + during that time! + """ + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + + # Check the user does not contribute to R30 yet. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "I'm still here", tok=access_token) + + # Notice that the user *still* does not contribute to R30! 
+ r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 0}) + + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "Still here!", tok=access_token) + + # *Now* the user appears in R30. + r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "unknown": 1}) -- cgit 1.5.1 From 36dc15412de9fc1bb2ba955c8b6f2da20d2ca20f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 16 Jul 2021 18:11:53 +0200 Subject: Add a module type for account validity (#9884) This adds an API for third-party plugin modules to implement account validity, so they can provide this feature instead of Synapse. The module implementing the current behaviour for this feature can be found at https://github.com/matrix-org/synapse-email-account-validity. To allow for a smooth transition between the current feature and the new module, hooks have been added to the existing account validity endpoints to allow their behaviours to be overridden by a module. --- changelog.d/9884.feature | 1 + docs/modules.md | 47 ++++- docs/sample_config.yaml | 85 --------- synapse/api/auth.py | 17 +- synapse/config/account_validity.py | 102 ++--------- synapse/handlers/account_validity.py | 128 ++++++++++++- synapse/handlers/register.py | 5 + synapse/module_api/__init__.py | 219 +++++++++++++++++++++-- synapse/module_api/errors.py | 6 +- synapse/push/pusherpool.py | 24 +-- synapse/rest/admin/users.py | 24 ++- synapse/rest/client/v2_alpha/account_validity.py | 7 +- tests/test_state.py | 1 + 13 files changed, 438 insertions(+), 228 deletions(-) create mode 100644 changelog.d/9884.feature (limited to 'tests') diff --git a/changelog.d/9884.feature b/changelog.d/9884.feature new file mode 100644 index 0000000000..525fd2f93c --- /dev/null +++ b/changelog.d/9884.feature @@ -0,0 +1 @@ +Add a module type for the account validity feature. diff --git a/docs/modules.md b/docs/modules.md index bec1c06d15..c4cb7018f7 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following API method: ```python -def ModuleApi.register_web_resource(path: str, resource: IResource) +def ModuleApi.register_web_resource(path: str, resource: IResource) -> None ``` The path is the full absolute path to register the resource at. For example, if you @@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c and is under no obligation to implement all callbacks from the categories it registers callbacks for. +Modules can register callbacks using one of the module API's `register_[...]_callbacks` +methods. The callback functions are passed to these methods as keyword arguments, with +the callback name as the argument name and the function as its value. This is demonstrated +in the example below. A `register_[...]_callbacks` method exists for each module type +documented in this section. + #### Spam checker callbacks -To register one of the callbacks described in this section, a module needs to use the -module API's `register_spam_checker_callbacks` method. The callback functions are passed -to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the -argument name and the function as its value. This is demonstrated in the example below. +Spam checker callbacks allow module developers to implement spam mitigation actions for +Synapse instances. 
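For illustration, a module might implement and register one of these callbacks along the
following lines. This is a minimal sketch: the class name and the `blocked_inviters`
option are invented for the example and are not part of Synapse itself.

```python
from synapse.module_api import ModuleApi


class ExampleInviteChecker:
    """Sketch of a module registering a spam checker callback."""

    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Hypothetical config option: users who are never allowed to send invites.
        self._blocked_inviters = set(config.get("blocked_inviters", []))

        # Only the callbacks a module implements need to be passed; every
        # keyword argument to register_spam_checker_callbacks is optional.
        api.register_spam_checker_callbacks(
            user_may_invite=self.user_may_invite,
        )

    async def user_may_invite(self, inviter: str, invitee: str, room_id: str) -> bool:
        # Allow the invite unless the inviter has been explicitly blocked.
        return inviter not in self._blocked_inviters
```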
Spam checker callbacks can be registered using the module API's
+`register_spam_checker_callbacks` method, as the sketch above demonstrates.
 
 The available spam checker callbacks are:
 
@@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
 
 Called when processing an invitation. The module must return a `bool` indicating whether
 the inviter can invite the invitee to the given room. Both inviter and invitee are
-represented by their Matrix user ID (i.e. `@alice:example.com`).
+represented by their Matrix user ID (e.g. `@alice:example.com`).
 
 ```python
 async def user_may_create_room(user: str) -> bool
@@ -188,6 +193,36 @@ async def check_media_file_for_spam(
 
 Called when storing a local or remote file. The module must return a boolean indicating
 whether the given file can be stored in the homeserver's media store.
 
+#### Account validity callbacks
+
+Account validity callbacks allow module developers to add extra steps to verify the
+validity of an account, i.e. see if a user can be granted access to their account on the
+Synapse instance. Account validity callbacks can be registered using the module API's
+`register_account_validity_callbacks` method.
+
+The available account validity callbacks are:
+
+```python
+async def is_user_expired(user: str) -> Optional[bool]
+```
+
+Called when processing any authenticated request (except for logout requests). The module
+can return a `bool` to indicate whether the user has expired and should be locked out of
+their account, or `None` if the module wasn't able to figure it out. The user is
+represented by their Matrix user ID (e.g. `@alice:example.com`).
+
+If the module returns `True`, the current request will be denied with the error code
+`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't
+invalidate the user's access token.
+
+```python
+async def on_user_registration(user: str) -> None
+```
+
+Called after successfully registering a user, in case the module needs to perform extra
+operations to keep track of them (e.g. add them to a database table). The user is
+represented by their Matrix user ID.
+
 ### Porting an existing module that uses the old interface
 
 In order to port a module that uses Synapse's old module interface, its author needs to:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index a45732a246..f4845a5841 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1310,91 +1310,6 @@ account_threepid_delegates:
 #auto_join_rooms_for_guests: false
 
 
-## Account Validity ##
-
-# Optional account validity configuration. This allows for accounts to be denied
-# any request after a given period.
-#
-# Once this feature is enabled, Synapse will look for registered users without an
-# expiration date at startup and will add one to every account it found using the
-# current settings at that time.
-# This means that, if a validity period is set, and Synapse is restarted (it will
-# then derive an expiration date from the current validity period), and some time
-# after that the validity period changes and Synapse is restarted, the users'
-# expiration dates won't be updated unless their account is manually renewed. This
-# date will be randomly selected within a range [now + period - d ; now + period],
-# where d is equal to 10% of the validity period.
-#
-account_validity:
-  # The account validity feature is disabled by default. Uncomment the
-  # following line to enable it.
-  #
-  #enabled: true
-
-  # The period after which an account is valid after its registration.
When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - - ## Metrics ### # Enable collection and rendering of performance metrics diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 8916e6fa2f..05699714ee 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -62,6 +62,7 @@ class Auth: self.clock = hs.get_clock() self.store = hs.get_datastore() self.state = hs.get_state_handler() + self._account_validity_handler = hs.get_account_validity_handler() self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" @@ -69,9 +70,6 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) - self._account_validity_enabled = ( - hs.config.account_validity.account_validity_enabled - ) self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users @@ -187,12 +185,17 @@ class Auth: shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. 
- if self._account_validity_enabled and not allow_expired: - if await self.store.is_account_expired( - user_info.user_id, self.clock.time_msec() + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + user_info.user_id ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired raise AuthError( - 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, ) device_id = user_info.device_id diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 957de7f3a6..6be4eafe55 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -18,6 +18,21 @@ class AccountValidityConfig(Config): section = "account_validity" def read_config(self, config, **kwargs): + """Parses the old account validity config. The config format looks like this: + + account_validity: + enabled: true + period: 6w + renew_at: 1w + renew_email_subject: "Renew your %(app)s account" + template_dir: "res/templates" + account_renewed_html_path: "account_renewed.html" + invalid_token_html_path: "invalid_token.html" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ account_validity_config = config.get("account_validity") or {} self.account_validity_enabled = account_validity_config.get("enabled", False) self.account_validity_renew_by_email_enabled = ( @@ -75,90 +90,3 @@ class AccountValidityConfig(Config): ], account_validity_template_dir, ) - - def generate_config_section(self, **kwargs): - return """\ - ## Account Validity ## - - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. 
'%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - # The currently available templates are: - # - # * account_renewed.html: Displayed to the user after they have successfully - # renewed their account. - # - # * account_previously_renewed.html: Displayed to the user if they attempt to - # renew their account with a token that is valid, but that has already - # been used. In this case the account is not renewed again. - # - # * invalid_token.html: Displayed to the user when they try to renew an account - # with an unknown or invalid renewal token. - # - # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for - # default template contents. - # - # The file name of some of these templates can be configured below for legacy - # reasons. - # - #template_dir: "res/templates" - - # A custom file name for the 'account_renewed.html' template. - # - # If not set, the file is assumed to be named "account_renewed.html". - # - #account_renewed_html_path: "account_renewed.html" - - # A custom file name for the 'invalid_token.html' template. - # - # If not set, the file is assumed to be named "invalid_token.html". - # - #invalid_token_html_path: "invalid_token.html" - """ diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d752cf34f0..078accd634 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -15,9 +15,11 @@ import email.mime.multipart import email.utils import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from synapse.api.errors import StoreError, SynapseError +from twisted.web.http import Request + +from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils @@ -27,6 +29,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Types for callbacks to be registered via the module api +IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]] +ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable] +# Temporary hooks to allow for a transition from `/_matrix/client` endpoints +# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`. 
+ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable] +ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]] +ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable] + class AccountValidityHandler: def __init__(self, hs: "HomeServer"): @@ -70,6 +81,99 @@ class AccountValidityHandler: if hs.config.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = [] + self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = [] + self._on_legacy_send_mail_callback: Optional[ + ON_LEGACY_SEND_MAIL_CALLBACK + ] = None + self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None + + # The legacy admin requests callback isn't a protected attribute because we need + # to access it from the admin servlet, which is outside of this handler. + self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None + + def register_account_validity_callbacks( + self, + is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, + on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, + on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, + on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None, + on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None, + ): + """Register callbacks from module for each hook.""" + if is_user_expired is not None: + self._is_user_expired_callbacks.append(is_user_expired) + + if on_user_registration is not None: + self._on_user_registration_callbacks.append(on_user_registration) + + # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and + # an admin one). As part of moving the feature into a module, we need to change + # the path from /_matrix/client/unstable/account_validity/... to + # /_synapse/client/account_validity, because: + # + # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix + # * the way we register servlets means that modules can't register resources + # under /_matrix/client + # + # We need to allow for a transition period between the old and new endpoints + # in order to allow for clients to update (and for emails to be processed). + # + # Once the email-account-validity module is loaded, it will take control of account + # validity by moving the rows from our `account_validity` table into its own table. + # + # Therefore, we need to allow modules (in practice just the one implementing the + # email-based account validity) to temporarily hook into the legacy endpoints so we + # can route the traffic coming into the old endpoints into the module, which is + # why we have the following three temporary hooks. + if on_legacy_send_mail is not None: + if self._on_legacy_send_mail_callback is not None: + raise RuntimeError("Tried to register on_legacy_send_mail twice") + + self._on_legacy_send_mail_callback = on_legacy_send_mail + + if on_legacy_renew is not None: + if self._on_legacy_renew_callback is not None: + raise RuntimeError("Tried to register on_legacy_renew twice") + + self._on_legacy_renew_callback = on_legacy_renew + + if on_legacy_admin_request is not None: + if self.on_legacy_admin_request_callback is not None: + raise RuntimeError("Tried to register on_legacy_admin_request twice") + + self.on_legacy_admin_request_callback = on_legacy_admin_request + + async def is_user_expired(self, user_id: str) -> bool: + """Checks if a user has expired against third-party modules. 
+ + Args: + user_id: The user to check the expiry of. + + Returns: + Whether the user has expired. + """ + for callback in self._is_user_expired_callbacks: + expired = await callback(user_id) + if expired is not None: + return expired + + if self._account_validity_enabled: + # If no module could determine whether the user has expired and the legacy + # configuration is enabled, fall back to it. + return await self.store.is_account_expired(user_id, self.clock.time_msec()) + + return False + + async def on_user_registration(self, user_id: str): + """Tell third-party modules about a user's registration. + + Args: + user_id: The ID of the newly registered user. + """ + for callback in self._on_user_registration_callbacks: + await callback(user_id) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time @@ -95,6 +199,17 @@ class AccountValidityHandler: Raises: SynapseError if the user is not set to renew. """ + # If a module supports sending a renewal email from here, do that, otherwise do + # the legacy dance. + if self._on_legacy_send_mail_callback is not None: + await self._on_legacy_send_mail_callback(user_id) + return + + if not self._account_validity_renew_by_email_enabled: + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) # If this user isn't set to be expired, raise an error. @@ -209,6 +324,10 @@ class AccountValidityHandler: token is considered stale. A token is stale if the 'token_used_ts_ms' db column is non-null. + This method exists to support handling the legacy account validity /renew + endpoint. If a module implements the on_legacy_renew callback, then this process + is delegated to the module instead. + Args: renewal_token: Token sent with the renewal request. Returns: @@ -218,6 +337,11 @@ class AccountValidityHandler: * An int representing the user's expiry timestamp as milliseconds since the epoch, or 0 if the token was invalid. """ + # If a module supports triggering a renew from here, do that, otherwise do the + # legacy dance. + if self._on_legacy_renew_callback is not None: + return await self._on_legacy_renew_callback(renewal_token) + try: ( user_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 26ef016179..056fe5e89f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -77,6 +77,7 @@ class RegistrationHandler(BaseHandler): self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() + self._account_validity_handler = hs.get_account_validity_handler() self._server_notices_mxid = hs.config.server_notices_mxid self._server_name = hs.hostname @@ -700,6 +701,10 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + # Only call the account validity module(s) on the main process, to avoid + # repeating e.g. database writes on all of the workers. + await self._account_validity_handler.on_user_registration(user_id) + async def register_device( self, user_id: str, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 308f045700..f3c78089b7 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -12,18 +12,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import email.utils import logging -from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +import jinja2 from twisted.internet import defer from twisted.web.resource import IResource from synapse.events import EventBase from synapse.http.client import SimpleHttpClient +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) +from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util import Clock +from synapse.util.caches.descriptors import cached if TYPE_CHECKING: from synapse.server import HomeServer @@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi are loaded into Synapse. """ -__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"] +__all__ = [ + "errors", + "make_deferred_yieldable", + "parse_json_object_from_request", + "respond_with_html", + "run_in_background", + "cached", + "UserID", + "DatabasePool", + "LoggingTransaction", + "DirectServeHtmlResource", + "DirectServeJsonResource", + "ModuleApi", +] logger = logging.getLogger(__name__) @@ -52,12 +89,27 @@ class ModuleApi: self._server_name = hs.hostname self._presence_stream = hs.get_event_sources().sources["presence"] self._state = hs.get_state_handler() + self._clock = hs.get_clock() # type: Clock + self._send_email_handler = hs.get_send_email_handler() + + try: + app_name = self._hs.config.email_app_name + + self._from_string = self._hs.config.email_notif_from % {"app": app_name} + except (KeyError, TypeError): + # If substitution failed (which can happen if the string contains + # placeholders other than just "app", or if the type of the placeholder is + # not a string), fall back to the bare strings. + self._from_string = self._hs.config.email_notif_from + + self._raw_from = email.utils.parseaddr(self._from_string)[1] # We expose these as properties below in order to attach a helpful docstring. self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) self._spam_checker = hs.get_spam_checker() + self._account_validity_handler = hs.get_account_validity_handler() ################################################################################# # The following methods should only be called during the module's initialisation. 
@@ -67,6 +119,11 @@ class ModuleApi:
         """Registers callbacks for spam checking capabilities."""
         return self._spam_checker.register_callbacks
 
+    @property
+    def register_account_validity_callbacks(self):
+        """Registers callbacks for account validity capabilities."""
+        return self._account_validity_handler.register_account_validity_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
@@ -101,22 +158,56 @@ class ModuleApi:
         """
         return self._public_room_list_manager
 
-    def get_user_by_req(self, req, allow_guest=False):
+    @property
+    def public_baseurl(self) -> str:
+        """The configured public base URL for this homeserver."""
+        return self._hs.config.public_baseurl
+
+    @property
+    def email_app_name(self) -> str:
+        """The application name configured in the homeserver's configuration."""
+        return self._hs.config.email.email_app_name
+
+    async def get_user_by_req(
+        self,
+        req: SynapseRequest,
+        allow_guest: bool = False,
+        allow_expired: bool = False,
+    ) -> Requester:
         """Check the access_token provided for a request
 
         Args:
-            req (twisted.web.server.Request): Incoming HTTP request
-            allow_guest (bool): True if guest users should be allowed. If this
+            req: Incoming HTTP request
+            allow_guest: True if guest users should be allowed. If this
                 is False, and the access token is for a guest user, an
                 AuthError will be thrown
+            allow_expired: True if expired users should be allowed. If this
+                is False, and the access token is for an expired user, an
+                AuthError will be thrown
+
         Returns:
-            twisted.internet.defer.Deferred[synapse.types.Requester]:
-                the requester for this request
+            The requester for this request
+
         Raises:
-            synapse.api.errors.AuthError: if no user by that token exists,
+            InvalidClientCredentialsError: if no user by that token exists,
                 or the token is invalid.
         """
-        return self._auth.get_user_by_req(req, allow_guest)
+        return await self._auth.get_user_by_req(
+            req,
+            allow_guest,
+            allow_expired=allow_expired,
+        )
+
+    async def is_user_admin(self, user_id: str) -> bool:
+        """Checks if a user is a server admin.
+
+        Args:
+            user_id: The Matrix ID of the user to check.
+
+        Returns:
+            True if the user is a server admin, False otherwise.
+        """
+        return await self._store.is_server_admin(UserID.from_string(user_id))
 
     def get_qualified_user_id(self, username):
         """Qualify a user id, if necessary
@@ -134,6 +225,32 @@ class ModuleApi:
             return username
         return UserID(username, self._hs.hostname).to_string()
 
+    async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
+        """Look up the profile info for the user with the given localpart.
+
+        Args:
+            localpart: The localpart to look up profile information for.
+
+        Returns:
+            The profile information (i.e. display name and avatar URL).
+        """
+        return await self._store.get_profileinfo(localpart)
+
+    async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
+        """Look up the threepids (email addresses and phone numbers) associated with the
+        given Matrix user ID.
+
+        Args:
+            user_id: The Matrix user ID to look up threepids for.
+
+        Returns:
+            A list of threepids, each threepid being represented by a dictionary
+            containing a "medium" key whose value is "email" for email addresses and
+            "msisdn" for phone numbers, and an "address" key whose value is the
+            threepid's address.
+        """
+        return await self._store.user_get_threepids(user_id)
+
     def check_user_exists(self, user_id):
         """Check if user exists.
@@ -464,6 +581,88 @@ class ModuleApi:
             presence_events, destination
         )
 
+    def looping_background_call(
+        self,
+        f: Callable,
+        msec: float,
+        *args,
+        desc: Optional[str] = None,
+        **kwargs,
+    ):
+        """Wraps a function as a background process and calls it repeatedly.
+
+        Waits `msec` initially before calling `f` for the first time.
+
+        Args:
+            f: The function to call repeatedly. f can be either synchronous or
+                asynchronous, and must follow Synapse's logcontext rules.
+                More info about logcontexts is available at
+                https://matrix-org.github.io/synapse/latest/log_contexts.html
+            msec: How long to wait between calls in milliseconds.
+            *args: Positional arguments to pass to function.
+            desc: The background task's description. Defaults to the function's name.
+            **kwargs: Keyword arguments to pass to function.
+        """
+        if desc is None:
+            desc = f.__name__
+
+        if self._hs.config.run_background_tasks:
+            self._clock.looping_call(
+                run_as_background_process,
+                msec,
+                desc,
+                f,
+                *args,
+                **kwargs,
+            )
+        else:
+            logger.warning(
+                "Not running looping call %s as the configuration forbids it",
+                f,
+            )
+
+    async def send_mail(
+        self,
+        recipient: str,
+        subject: str,
+        html: str,
+        text: str,
+    ):
+        """Send an email on behalf of the homeserver.
+
+        Args:
+            recipient: The email address for the recipient.
+            subject: The email's subject.
+            html: The email's HTML content.
+            text: The email's text content.
+        """
+        await self._send_email_handler.send_email(
+            email_address=recipient,
+            subject=subject,
+            app_name=self.email_app_name,
+            html=html,
+            text=text,
+        )
+
+    def read_templates(
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
+    ) -> List[jinja2.Template]:
+        """Read and load the content of the template files at the given location.
+        By default, Synapse will look for these templates in its configured template
+        directory, but another directory to search in can be provided.
+
+        Args:
+            filenames: The name of the template files to look for.
+            custom_template_directory: An additional directory to look for the files in.
+
+        Returns:
+            A list containing the loaded templates, in the same order as the
+            filenames parameter.
+        """
+        return self._hs.config.read_templates(filenames, custom_template_directory)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 02bbb0be39..98ea911a81 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -14,5 +14,9 @@
 
 """Exception types which are exposed as part of the stable module API"""
 
-from synapse.api.errors import RedirectException, SynapseError  # noqa: F401
+from synapse.api.errors import (  # noqa: F401
+    InvalidClientCredentialsError,
+    RedirectException,
+    SynapseError,
+)
 from synapse.config._base import ConfigError  # noqa: F401
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 2519ad76db..85621f33ef 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,10 +62,6 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
-        self._account_validity_enabled = (
-            hs.config.account_validity.account_validity_enabled
-        )
-
         # We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() @@ -89,6 +85,8 @@ class PusherPool: # map from user id to app_id:pushkey to pusher self.pushers: Dict[str, Dict[str, Pusher]] = {} + self._account_validity_handler = hs.get_account_validity_handler() + def start(self) -> None: """Starts the pushers off in a background process.""" if not self._should_start_pushers: @@ -238,12 +236,9 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): @@ -268,12 +263,9 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity_enabled: - expired = await self.store.is_account_expired( - u, self.clock.time_msec() - ) - if expired: - continue + expired = await self._account_validity_handler.is_user_expired(u) + if expired: + continue if u in self.pushers: for p in self.pushers[u].values(): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 7d75564758..06e6ccee42 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - body = parse_json_object_from_request(request) + if self.account_activity_handler.on_legacy_admin_request_callback: + expiration_ts = await ( + self.account_activity_handler.on_legacy_admin_request_callback(request) + ) + else: + body = parse_json_object_from_request(request) - if "user_id" not in body: - raise SynapseError(400, "Missing property 'user_id' in the request body") + if "user_id" not in body: + raise SynapseError( + 400, + "Missing property 'user_id' in the request body", + ) - expiration_ts = await self.account_activity_handler.renew_account_for_user( - body["user_id"], - body.get("expiration_ts"), - not body.get("enable_renewal_emails", True), - ) + expiration_ts = await self.account_activity_handler.renew_account_for_user( + body["user_id"], + body.get("expiration_ts"), + not body.get("enable_renewal_emails", True), + ) res = {"expiration_ts": expiration_ts} return 200, res diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 2d1ad3d3fb..3ebe401861 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -14,7 +14,7 @@ import logging -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet @@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet): ) async def on_POST(self, request): - if not self.account_validity_renew_by_email_enabled: - raise AuthError( - 403, "Account renewal via email is disabled on this server." 
-        )
-
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
         await self.account_activity_handler.send_renewal_email_to_user(user_id)
diff --git a/tests/test_state.py b/tests/test_state.py
index 780eba823c..e5488df1ac 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase):
                 "get_state_handler",
                 "get_clock",
                 "get_state_resolution_handler",
+                "get_account_validity_handler",
                 "hostname",
             ]
         )
-- 
cgit 1.5.1


From 4e340412c020f685cb402a735b983f6e332e206b Mon Sep 17 00:00:00 2001
From: reivilibre <38398653+reivilibre@users.noreply.github.com>
Date: Mon, 19 Jul 2021 16:11:34 +0100
Subject: Add a new version of the R30 phone-home metric, which removes a false
 impression of retention given by the old R30 metric (#10332)

Signed-off-by: Olivier Wilkinson (reivilibre)
---
 changelog.d/10332.feature                 |   1 +
 synapse/app/phone_stats_home.py           |   4 +
 synapse/storage/databases/main/metrics.py | 129 ++++++++++++++++
 tests/app/test_phone_stats_home.py        | 242 ++++++++++++++++++++++++++++++
 tests/rest/client/v1/utils.py             |  30 +++-
 tests/unittest.py                         |  15 +-
 6 files changed, 416 insertions(+), 5 deletions(-)
 create mode 100644 changelog.d/10332.feature

(limited to 'tests')

diff --git a/changelog.d/10332.feature b/changelog.d/10332.feature
new file mode 100644
index 0000000000..091947ff22
--- /dev/null
+++ b/changelog.d/10332.feature
@@ -0,0 +1 @@
+Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 8f86cecb76..7904c246df 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -107,6 +107,10 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
     for name, count in r30_results.items():
         stats["r30_users_" + name] = count
 
+    r30v2_results = await hs.get_datastore().count_r30v2_users()
+    for name, count in r30v2_results.items():
+        stats["r30v2_users_" + name] = count
+
     stats["cache_factor"] = hs.config.caches.global_factor
     stats["event_cache_size"] = hs.config.caches.event_cache_size
 
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index e3a544d9b2..dc0bbc56ac 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -316,6 +316,135 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
 
+    async def count_r30v2_users(self) -> Dict[str, int]:
+        """
+        Counts the number of 30 day retained users, defined as users that:
+         - Appear more than once in the past 60 days
+         - Have more than 30 days between the most and least recent appearances that
+           occurred in the past 60 days.
+
+        (This is the second version of this metric, hence R30'v2')
+
+        Returns:
+            A mapping from client type to the number of 30-day retained users for that client.
+
+            The dict keys are:
+             - "all" (a combined number of users across any and all clients)
+             - "android" (Element Android)
+             - "ios" (Element iOS)
+             - "electron" (Element Desktop)
+             - "web" (any web application -- it's not possible to distinguish Element Web here)
+        """
+
+        def _count_r30v2_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
+            one_day_from_now_in_secs = now + 86400
+
+            # This is the 'per-platform' count.
+ sql = """ + SELECT + client_type, + count(client_type) + FROM + ( + SELECT + user_id, + CASE + WHEN + LOWER(user_agent) LIKE '%%riot%%' OR + LOWER(user_agent) LIKE '%%element%%' + THEN CASE + WHEN + LOWER(user_agent) LIKE '%%electron%%' + THEN 'electron' + WHEN + LOWER(user_agent) LIKE '%%android%%' + THEN 'android' + WHEN + LOWER(user_agent) LIKE '%%ios%%' + THEN 'ios' + ELSE 'unknown' + END + WHEN + LOWER(user_agent) LIKE '%%mozilla%%' OR + LOWER(user_agent) LIKE '%%gecko%%' + THEN 'web' + ELSE 'unknown' + END as client_type + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id, + client_type + HAVING + max(timestamp) - min(timestamp) > ? + ) AS temp + GROUP BY + client_type + ; + """ + + # We initialise all the client types to zero, so we get an explicit + # zero if they don't appear in the query results + results = {"ios": 0, "android": 0, "web": 0, "electron": 0} + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + + for row in txn: + if row[0] == "unknown": + continue + results[row[0]] = row[1] + + # This is the 'all users' count. + sql = """ + SELECT COUNT(*) FROM ( + SELECT + 1 + FROM + user_daily_visits + WHERE + timestamp > ? + AND + timestamp < ? + GROUP BY + user_id + HAVING + max(timestamp) - min(timestamp) > ? + ) AS r30_users + """ + + txn.execute( + sql, + ( + sixty_days_ago_in_secs * 1000, + one_day_from_now_in_secs * 1000, + thirty_days_in_secs * 1000, + ), + ) + row = txn.fetchone() + if row is None: + results["all"] = 0 + else: + results["all"] = row[0] + + return results + + return await self.db_pool.runInteraction( + "count_r30v2_users", _count_r30v2_users + ) + def _get_start_of_day(self): """ Returns millisecond unixtime for start of UTC day. diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 2da6ba4dde..5527e278db 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,9 +1,11 @@ import synapse +from synapse.app.phone_stats_home import start_phone_stats_home from synapse.rest.client.v1 import login, room from tests import unittest from tests.unittest import HomeserverTestCase +FIVE_MINUTES_IN_SECONDS = 300 ONE_DAY_IN_SECONDS = 86400 @@ -151,3 +153,243 @@ class PhoneHomeTestCase(HomeserverTestCase): # *Now* the user appears in R30. r30_results = self.get_success(self.hs.get_datastore().count_r30_users()) self.assertEqual(r30_results, {"all": 1, "unknown": 1}) + + +class PhoneHomeR30V2TestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def _advance_to(self, desired_time_secs: float): + now = self.hs.get_clock().time() + assert now < desired_time_secs + self.reactor.advance(desired_time_secs - now) + + def make_homeserver(self, reactor, clock): + hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) + + # We don't want our tests to actually report statistics, so check + # that it's not enabled + assert not hs.config.report_stats + + # This starts the needed data collection that we rely on to calculate + # R30v2 metrics. + start_phone_stats_home(hs) + return hs + + def test_r30v2_minimum_usage(self): + """ + Tests the minimum amount of interaction necessary for the R30v2 metric + to consider a user 'retained'. 
+ """ + + # Register a user, log it in, create a room and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!") + room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) + self.helper.send(room_id, "message", tok=access_token) + first_post_at = self.hs.get_clock().time() + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the R30 results do not count that user. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance 31 days. + # (R30v2 includes users with **more** than 30 days between the two visits, + # and user_daily_visits records the timestamp as the start of the day.) + self.reactor.advance(31 * ONE_DAY_IN_SECONDS) + # Also advance 5 minutes to let another user_daily_visits update occur + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # (Make sure the user isn't somehow counted by this point.) + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Send a message (this counts as activity) + self.helper.send(room_id, "message2", tok=access_token) + + # We have to wait a few minutes for the user_daily_visits table to + # be updated by a background process. + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user is counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance to JUST under 60 days after the user's first post + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5) + + # Check the user is still counted. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Advance into the next day. The user's first activity is now more than 60 days old. + self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5) + + # Check the user is now no longer counted in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_user_must_be_retained_for_at_least_a_month(self): + """ + Tests that a newly-registered user must be retained for a whole month + before appearing in the R30v2 statistic, even if they post every day + during that time! + """ + + # set a custom user-agent to impersonate Element/Android. + headers = ( + ( + "User-Agent", + "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check the user does not contribute to R30 yet. 
+ r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + for _ in range(30): + # This loop posts a message every day for 30 days + self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS) + self.helper.send( + room_id, "I'm still here", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Notice that the user *still* does not contribute to R30! + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # advance yet another day with more activity + self.reactor.advance(ONE_DAY_IN_SECONDS) + self.helper.send( + room_id, "Still here!", tok=access_token, custom_headers=headers + ) + + # give time for user_daily_visits to update + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # *Now* the user appears in R30. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0} + ) + + def test_r30v2_returning_dormant_users_not_counted(self): + """ + Tests that dormant users (users inactive for a long time) do not + contribute to R30v2 when they return for just a single day. + This is a key difference between R30 and R30v2. + """ + + # set a custom user-agent to impersonate Element/iOS. + headers = ( + ( + "User-Agent", + "Riot/1.4 (iPhone; iOS 13; Scale/4.00)", + ), + ) + + # Register a user and send a message + user_id = self.register_user("u1", "secret!") + access_token = self.login("u1", "secret!", custom_headers=headers) + room_id = self.helper.create_room_as( + room_creator=user_id, tok=access_token, custom_headers=headers + ) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # the user goes inactive for 2 months + self.reactor.advance(60 * ONE_DAY_IN_SECONDS) + + # the user returns for one day, perhaps just to check out a new feature + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # Give time for user_daily_visits table to be updated. + # (user_daily_visits is updated every 5 minutes using a looping call.) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + store = self.hs.get_datastore() + + # Check that the user does not contribute to R30v2, even though it's been + # more than 30 days since registration. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) + + # Check that this is a situation where old R30 differs: + # old R30 DOES count this as 'retained'. + r30_results = self.get_success(store.count_r30_users()) + self.assertEqual(r30_results, {"all": 1, "ios": 1}) + + # Now we want to check that the user will still be able to appear in + # R30v2 as long as the user performs some other activity between + # 30 and 60 days later. + self.reactor.advance(32 * ONE_DAY_IN_SECONDS) + self.helper.send(room_id, "message", tok=access_token, custom_headers=headers) + + # (give time for tables to update) + self.reactor.advance(FIVE_MINUTES_IN_SECONDS) + + # Check the user now satisfies the requirements to appear in R30v2. 
+ r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 59.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS) + + # Check the user still appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0} + ) + + # Advance to 60.5 days after the user's first R30v2-eligible activity. + self.reactor.advance(ONE_DAY_IN_SECONDS) + + # Check the user no longer appears in R30v2. + r30_results = self.get_success(store.count_r30v2_users()) + self.assertEqual( + r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 69798e95c3..fc2d35596e 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -19,7 +19,7 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Mapping, MutableMapping, Optional +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from unittest.mock import patch import attr @@ -53,6 +53,9 @@ class RestHelper: tok: str = None, expect_code: int = 200, extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> str: """ Create a room. @@ -87,6 +90,7 @@ class RestHelper: "POST", path, json.dumps(content).encode("utf8"), + custom_headers=custom_headers, ) assert channel.result["code"] == b"%d" % expect_code, channel.result @@ -175,14 +179,30 @@ class RestHelper: self.auth_user_id = temp_id - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): if body is None: body = "body_text_here" content = {"msgtype": "m.text", "body": body} return self.send_event( - room_id, "m.room.message", content, txn_id, tok, expect_code + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, ) def send_event( @@ -193,6 +213,9 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -207,6 +230,7 @@ class RestHelper: "PUT", path, json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, ) assert ( diff --git a/tests/unittest.py b/tests/unittest.py index c6d9064423..3eec9c4d5b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase): user_id = channel.json_body["user_id"] return user_id - def login(self, username, password, device_id=None): + def login( + self, + username, + password, + device_id=None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): """ Log in a user, and get an access token. Requires the Login API be registered. 
@@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase): body["device_id"] = device_id channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") + "POST", + "/_matrix/client/r0/login", + json.dumps(body).encode("utf8"), + custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1 From a743bf46949e851c9a10d8e01a138659f3af2484 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 20 Jul 2021 12:39:46 +0200 Subject: Port the ThirdPartyEventRules module interface to the new generic interface (#10386) Port the third-party event rules interface to the generic module interface introduced in v1.37.0 --- changelog.d/10386.removal | 1 + docs/modules.md | 62 ++++++- docs/sample_config.yaml | 13 -- docs/upgrade.md | 13 ++ synapse/app/_base.py | 2 + synapse/config/third_party_event_rules.py | 15 -- synapse/events/third_party_rules.py | 245 +++++++++++++++++++++++----- synapse/handlers/federation.py | 4 +- synapse/handlers/message.py | 8 +- synapse/handlers/room.py | 10 +- synapse/module_api/__init__.py | 6 + tests/rest/client/test_third_party_rules.py | 132 ++++++++++++--- 12 files changed, 403 insertions(+), 108 deletions(-) create mode 100644 changelog.d/10386.removal (limited to 'tests') diff --git a/changelog.d/10386.removal b/changelog.d/10386.removal new file mode 100644 index 0000000000..800a6143d7 --- /dev/null +++ b/changelog.d/10386.removal @@ -0,0 +1 @@ +The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information. diff --git a/docs/modules.md b/docs/modules.md index c4cb7018f7..9a430390a4 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -186,7 +186,7 @@ The arguments passed to this callback are: ```python async def check_media_file_for_spam( file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper", - file_info: "synapse.rest.media.v1._base.FileInfo" + file_info: "synapse.rest.media.v1._base.FileInfo", ) -> bool ``` @@ -223,6 +223,66 @@ Called after successfully registering a user, in case the module needs to perfor operations to keep track of them. (e.g. add them to a database table). The user is represented by their Matrix user ID. +#### Third party rules callbacks + +Third party rules callbacks allow module developers to add extra checks to verify the +validity of incoming events. Third party event rules callbacks can be registered using +the module API's `register_third_party_rules_callbacks` method. + +The available third party rules callbacks are: + +```python +async def check_event_allowed( + event: "synapse.events.EventBase", + state_events: "synapse.types.StateMap", +) -> Tuple[bool, Optional[dict]] +``` + +** +This callback is very experimental and can and will break without notice. Module developers +are encouraged to implement `check_event_for_spam` from the spam checker category instead. +** + +Called when processing any incoming event, with the event and a `StateMap` +representing the current state of the room the event is being sent into. A `StateMap` is +a dictionary that maps tuples containing an event type and a state key to the +corresponding state event. For example retrieving the room's `m.room.create` event from +the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`. 
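+
+As a purely illustrative sketch (the event type and rule here are made up, not
+part of Synapse), a callback could use that state map like this:
+
+```python
+async def check_event_allowed(event, state_events):
+    # Only let the room's creator send a hypothetical restricted event type.
+    if event.type == "org.example.restricted":
+        create_event = state_events.get(("m.room.create", ""))
+        if create_event and event.sender != create_event.sender:
+            return False, None
+    return True, None
+```
+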
+The module must return a boolean indicating whether the event can be allowed. + +Note that this callback function processes incoming events coming via federation +traffic (on top of client traffic). This means denying an event might cause the local +copy of the room's history to diverge from that of remote servers. This may cause +federation issues in the room. It is strongly recommended to only deny events using this +callback function if the sender is a local user, or in a private federation in which all +servers are using the same module, with the same configuration. + +If the boolean returned by the module is `True`, it may also tell Synapse to replace the +event with new data by returning the new event's data as a dictionary. In order to do +that, it is recommended the module calls `event.get_dict()` to get the current event as a +dictionary, and modify the returned dictionary accordingly. + +Note that replacing the event only works for events sent by local users, not for events +received over federation. + +```python +async def on_create_room( + requester: "synapse.types.Requester", + request_content: dict, + is_requester_admin: bool, +) -> None +``` + +Called when processing a room creation request, with the `Requester` object for the user +performing the request, a dictionary representing the room creation request's JSON body +(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom) +for a list of possible parameters), and a boolean indicating whether the user performing +the request is a server admin. + +Modules can modify the `request_content` (by e.g. adding events to its `initial_state`), +or deny the room's creation by raising a `module_api.errors.SynapseError`. + + ### Porting an existing module that uses the old interface In order to port a module that uses Synapse's old module interface, its author needs to: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f4845a5841..853c2f6899 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2654,19 +2654,6 @@ stats: # action: allow -# Server admins can define a Python module that implements extra rules for -# allowing or denying incoming events. In order to work, this module needs to -# override the methods defined in synapse/events/third_party_rules.py. -# -# This feature is designed to be used in closed federations only, where each -# participating server enforces the same rules. -# -#third_party_event_rules: -# module: "my_custom_project.SuperRulesSet" -# config: -# example_option: 'things' - - ## Opentracing ## # These settings enable opentracing, which implements distributed tracing. diff --git a/docs/upgrade.md b/docs/upgrade.md index db0450f563..c8f4a2c171 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -86,6 +86,19 @@ process, for example: ``` +# Upgrading to v1.39.0 + +## Deprecation of the current third-party rules module interface + +The current third-party rules module interface is deprecated in favour of the new generic +modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer +to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface) +to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules) +to update their configuration once the modules they are using have been updated. + +We plan to remove support for the current third-party rules interface in September 2021. 
+ + # Upgrading to v1.38.0 ## Re-indexing of `events` table on Postgres databases diff --git a/synapse/app/_base.py b/synapse/app/_base.py index b30571fe49..50a02f51f5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -368,6 +369,7 @@ async def start(hs: "HomeServer"): module(config=config, api=module_api) load_legacy_spam_checkers(hs) + load_legacy_third_party_event_rules(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index f502ff539e..a3fae02420 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config): self.third_party_event_rules = load_module( provider, ("third_party_event_rules",) ) - - def generate_config_section(self, **kwargs): - return """\ - # Server admins can define a Python module that implements extra rules for - # allowing or denying incoming events. In order to work, this module needs to - # override the methods defined in synapse/events/third_party_rules.py. - # - # This feature is designed to be used in closed federations only, where each - # participating server enforces the same rules. - # - #third_party_event_rules: - # module: "my_custom_project.SuperRulesSet" - # config: - # example_option: 'things' - """ diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index f7944fd834..7a6eb3e516 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -11,16 +11,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple -from typing import TYPE_CHECKING, Union - +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Requester, StateMap +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + + +CHECK_EVENT_ALLOWED_CALLBACK = Callable[ + [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] +] +ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] +CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ + [str, str, StateMap[EventBase]], Awaitable[bool] +] +CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ + [str, StateMap[EventBase], str], Awaitable[bool] +] + + +def load_legacy_third_party_event_rules(hs: "HomeServer"): + """Wrapper that loads a third party event rules module configured using the old + configuration, and registers the hooks they implement. 
+ """ + if hs.config.third_party_event_rules is None: + return + + module, config = hs.config.third_party_event_rules + + api = hs.get_module_api() + third_party_rules = module(config=config, module_api=api) + + # The known hooks. If a module implements a method which name appears in this set, + # we'll want to register it. + third_party_event_rules_methods = { + "check_event_allowed", + "on_create_room", + "check_threepid_can_be_invited", + "check_visibility_can_be_modified", + } + + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We return a separate wrapper for these methods because, in order to wrap them + # correctly, we need to await its result. Therefore it doesn't make a lot of + # sense to make it go through the run() wrapper. + if f.__name__ == "check_event_allowed": + + # We need to wrap check_event_allowed because its old form would return either + # a boolean or a dict, but now we want to return the dict separately from the + # boolean. + async def wrap_check_event_allowed( + event: EventBase, + state_events: StateMap[EventBase], + ) -> Tuple[bool, Optional[dict]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(event, state_events) + if isinstance(res, dict): + return True, res + else: + return res, None + + return wrap_check_event_allowed + + if f.__name__ == "on_create_room": + + # We need to wrap on_create_room because its old form would return a boolean + # if the room creation is denied, but now we just want it to raise an + # exception. + async def wrap_on_create_room( + requester: Requester, config: dict, is_requester_admin: bool + ) -> None: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + res = await f(requester, config, is_requester_admin) + if res is False: + raise SynapseError( + 403, + "Room creation forbidden with these parameters", + ) + + return wrap_on_create_room + + def run(*args, **kwargs): + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None + + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. 
+ hooks = { + hook: async_wrapper(getattr(third_party_rules, hook, None)) + for hook in third_party_event_rules_methods + } + + api.register_third_party_rules_callbacks(**hooks) + class ThirdPartyEventRules: """Allows server admins to provide a Python module implementing an extra @@ -35,36 +143,65 @@ class ThirdPartyEventRules: self.store = hs.get_datastore() - module = None - config = None - if hs.config.third_party_event_rules: - module, config = hs.config.third_party_event_rules + self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] + self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] + self._check_threepid_can_be_invited_callbacks: List[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = [] + self._check_visibility_can_be_modified_callbacks: List[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = [] + + def register_third_party_rules_callbacks( + self, + check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, + on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, + check_threepid_can_be_invited: Optional[ + CHECK_THREEPID_CAN_BE_INVITED_CALLBACK + ] = None, + check_visibility_can_be_modified: Optional[ + CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK + ] = None, + ): + """Register callbacks from modules for each hook.""" + if check_event_allowed is not None: + self._check_event_allowed_callbacks.append(check_event_allowed) + + if on_create_room is not None: + self._on_create_room_callbacks.append(on_create_room) + + if check_threepid_can_be_invited is not None: + self._check_threepid_can_be_invited_callbacks.append( + check_threepid_can_be_invited, + ) - if module is not None: - self.third_party_rules = module( - config=config, - module_api=hs.get_module_api(), + if check_visibility_can_be_modified is not None: + self._check_visibility_can_be_modified_callbacks.append( + check_visibility_can_be_modified, ) async def check_event_allowed( self, event: EventBase, context: EventContext - ) -> Union[bool, dict]: + ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. The module can return: * True: the event is allowed. * False: the event is not allowed, and should be rejected with M_FORBIDDEN. - * a dict: replacement event data. + + If the event is allowed, the module can also return a dictionary to use as a + replacement for the event. Args: event: The event to be checked. context: The context of the event. Returns: - The result from the ThirdPartyRules module, as above + The result from the ThirdPartyRules module, as above. """ - if self.third_party_rules is None: - return True + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_event_allowed_callbacks) == 0: + return True, None prev_state_ids = await context.get_prev_state_ids() @@ -77,29 +214,46 @@ class ThirdPartyEventRules: # the hashes and signatures. event.freeze() - return await self.third_party_rules.check_event_allowed(event, state_events) + for callback in self._check_event_allowed_callbacks: + try: + res, replacement_data = await callback(event, state_events) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue + + # Return if the event shouldn't be allowed or if the module came up with a + # replacement dict for the event. 
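+            # (The first callback that denies the event or supplies
+            # replacement data short-circuits the loop; any remaining
+            # callbacks are not consulted for this event.)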
+ if res is False: + return res, None + elif isinstance(replacement_data, dict): + return True, replacement_data + + return True, None async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool - ) -> bool: - """Intercept requests to create room to allow, deny or update the - request config. + ) -> None: + """Intercept requests to create room to maybe deny it (via an exception) or + update the request config. Args: requester config: The creation config from the client. is_requester_admin: If the requester is an admin - - Returns: - Whether room creation is allowed or denied. """ - - if self.third_party_rules is None: - return True - - return await self.third_party_rules.on_create_room( - requester, config, is_requester_admin - ) + for callback in self._on_create_room_callbacks: + try: + await callback(requester, config, is_requester_admin) + except Exception as e: + # Don't silence the errors raised by this callback since we expect it to + # raise an exception to deny the creation of the room; instead make sure + # it's a SynapseError we can send to clients. + if not isinstance(e, SynapseError): + e = SynapseError( + 403, "Room creation forbidden with these parameters" + ) + + raise e async def check_threepid_can_be_invited( self, medium: str, address: str, room_id: str @@ -114,15 +268,20 @@ class ThirdPartyEventRules: Returns: True if the 3PID can be invited, False if not. """ - - if self.third_party_rules is None: + # Bail out early without hitting the store if we don't have any callbacks to run. + if len(self._check_threepid_can_be_invited_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) + for callback in self._check_threepid_can_be_invited_callbacks: + try: + if await callback(medium, address, state_events) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def check_visibility_can_be_modified( self, room_id: str, new_visibility: str @@ -137,18 +296,20 @@ class ThirdPartyEventRules: Returns: True if the room's visibility can be modified, False if not. """ - if self.third_party_rules is None: - return True - - check_func = getattr( - self.third_party_rules, "check_visibility_can_be_modified", None - ) - if not check_func or not callable(check_func): + # Bail out early without hitting the store if we don't have any callback + if len(self._check_visibility_can_be_modified_callbacks) == 0: return True state_events = await self._get_state_map_for_room(room_id) - return await check_func(room_id, state_events, new_visibility) + for callback in self._check_visibility_can_be_modified_callbacks: + try: + if await callback(room_id, state_events, new_visibility) is False: + return False + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + + return True async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index cf389be3e4..5728719909 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler): # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c7fe4ff89e..8a0024ce84 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -949,10 +949,10 @@ class EventCreationHandler: if requester: context.app_service = requester.app_service - third_party_result = await self.third_party_event_rules.check_event_allowed( + res, new_content = await self.third_party_event_rules.check_event_allowed( event, context ) - if not third_party_result: + if res is False: logger.info( "Event %s forbidden by third-party rules", event, @@ -960,11 +960,11 @@ class EventCreationHandler: raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - elif isinstance(third_party_result, dict): + elif new_content is not None: # the third-party rules want to replace the event. We'll need to build a new # event. event, context = await self._rebuild_event_after_third_party_rules( - third_party_result, event + new_content, event ) self.validator.validate_new(event, self.config) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 64656fda22..370561e549 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler): else: is_requester_admin = await self.auth.is_server_admin(requester.user) - # Check whether the third party rules allows/changes the room create - # request. - event_allowed = await self.third_party_event_rules.on_create_room( + # Let the third party rules modify the room creation config if needed, or abort + # the room creation entirely with an exception. + await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) - if not event_allowed: - raise SynapseError( - 403, "You are not permitted to create rooms", Codes.FORBIDDEN - ) if not is_requester_admin and not await self.spam_checker.user_may_create_room( user_id diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 5df9349134..1259fc2d90 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -110,6 +110,7 @@ class ModuleApi: self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() + self._third_party_event_rules = hs.get_third_party_event_rules() ################################################################################# # The following methods should only be called during the module's initialisation. 
@@ -124,6 +125,11 @@ class ModuleApi: """Registers callbacks for account validity capabilities.""" return self._account_validity_handler.register_account_validity_callbacks + @property + def register_third_party_rules_callbacks(self): + """Registers callbacks for third party event rules capabilities.""" + return self._third_party_event_rules.register_third_party_rules_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index c5e1c5458b..28dd47a28b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -16,17 +16,19 @@ from typing import Dict from unittest.mock import Mock from synapse.events import EventBase +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import Requester, StateMap +from synapse.util.frozenutils import unfreeze from tests import unittest thread_local = threading.local() -class ThirdPartyRulesTestModule: +class LegacyThirdPartyRulesTestModule: def __init__(self, config: Dict, module_api: ModuleApi): # keep a record of the "current" rules module, so that the test can patch # it if desired. @@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule: return config -def current_rules_module() -> ThirdPartyRulesTestModule: - return thread_local.rules_module +class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return False + + +class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + d = event.get_dict() + content = unfreeze(event.content) + content["foo"] = "bar" + d["content"] = content + return d class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): @@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def default_config(self): - config = super().default_config() - config["third_party_event_rules"] = { - "module": __name__ + ".ThirdPartyRulesTestModule", - "config": {}, - } - return config + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + load_legacy_third_party_event_rules(hs) + + return hs def prepare(self, reactor, clock, homeserver): # Create a user and room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.tok = self.login("kermit", "monkey") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + # Some tests might prevent room creation on purpose. 
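+        # (For instance, with LegacyDenyNewRooms loaded, room creation is
+        # rejected and the assertion in create_room_as fires; the except
+        # below keeps setup usable for such tests.)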
+ try: + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + except Exception: + pass def test_third_party_rules(self): """Tests that a forbidden event is forbidden from being sent, but an allowed one @@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # patch the rules module with a Mock which will return False for some event # types async def check(ev, state): - return ev.type != "foo.bar.forbidden" + return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) - current_rules_module().check_event_allowed = callback + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [ + callback + ] channel = self.make_request( "PUT", @@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # first patch the event checker so that it will try to modify the event async def check(ev: EventBase, state): ev.content = {"x": "y"} - return True + return True, None - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"500", channel.result) + # check_event_allowed has some error handling, so it shouldn't 500 just because a + # module did something bad. + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "x") def test_modify_event(self): """The module can return a modified version of the event""" @@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): async def check(ev: EventBase, state): d = ev.get_dict() d["content"] = {"x": "y"} - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "msgtype": "m.text", "body": d["content"]["body"].upper(), } - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # Send an event, then edit it. 
channel = self.make_request( @@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.assertEqual(ev["content"]["body"], "EDITED BODY") def test_send_event(self): - """Tests that the module can send an event into a room via the module api""" + """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", "body": "Hello!", @@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "sender": self.user_id, } event: EventBase = self.get_success( - current_rules_module().module_api.create_and_send_event_into_room( - event_dict - ) + self.hs.get_module_api().create_and_send_event_into_room(event_dict) ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) self.assertEquals(event.type, "m.room.message") self.assertEquals(event.content, content) + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyChangeEvents", + "config": {}, + } + } + ) + def test_legacy_check_event_allowed(self): + """Tests that the wrapper for legacy check_event_allowed callbacks works + correctly. + """ + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + self.assertIn("foo", channel.json_body["content"].keys()) + self.assertEqual(channel.json_body["content"]["foo"], "bar") + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyDenyNewRooms", + "config": {}, + } + } + ) + def test_legacy_on_create_room(self): + """Tests that the wrapper for legacy on_create_room callbacks works + correctly. + """ + self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) -- cgit 1.5.1 From 54389d5697622f1beffaeda96d9c6da7ef7d93a9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 20 Jul 2021 14:24:25 +0100 Subject: Fix dropping locks on shut down (#10433) --- changelog.d/10433.bugfix | 1 + synapse/storage/databases/main/lock.py | 6 +++++- tests/storage/databases/main/test_lock.py | 13 +++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10433.bugfix (limited to 'tests') diff --git a/changelog.d/10433.bugfix b/changelog.d/10433.bugfix new file mode 100644 index 0000000000..ad85a83f96 --- /dev/null +++ b/changelog.d/10433.bugfix @@ -0,0 +1 @@ +Fix error while dropping locks on shutdown. Introduced in v1.38.0. diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 774861074c..3d1dff660b 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -78,7 +78,11 @@ class LockStore(SQLBaseStore): """Called when the server is shutting down""" logger.info("Dropping held locks due to shutdown") - for (lock_name, lock_key), token in self._live_tokens.items(): + # We need to take a copy of the tokens dict as dropping the locks will + # cause the dictionary to change. 
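+        # (Iterating over self._live_tokens directly while the locks are
+        # being dropped would raise "RuntimeError: dictionary changed size
+        # during iteration".)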
+ tokens = dict(self._live_tokens) + + for (lock_name, lock_key), token in tokens.items(): await self._drop_lock(lock_name, lock_key, token) logger.info("Dropped locks due to shutdown") diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 9ca70e7367..d326a1d6a6 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -98,3 +98,16 @@ class LockTestCase(unittest.HomeserverTestCase): lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) self.assertIsNotNone(lock2) + + def test_shutdown(self): + """Test that shutting down Synapse releases the locks""" + # Acquire two locks + lock = self.get_success(self.store.try_acquire_lock("name", "key1")) + self.assertIsNotNone(lock) + lock2 = self.get_success(self.store.try_acquire_lock("name", "key2")) + self.assertIsNotNone(lock2) + + # Now call the shutdown code + self.get_success(self.store._on_shutdown()) + + self.assertEqual(self.store._live_tokens, {}) -- cgit 1.5.1 From 74d09a43d9e0f65f1292aa51f58ea676e4aefc7f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 27 Jul 2021 14:36:38 +0100 Subject: Always communicate device OTK counts to clients (#10485) Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- changelog.d/10485.bugfix | 1 + synapse/api/constants.py | 8 ++++++++ synapse/handlers/sync.py | 4 ++++ synapse/storage/databases/main/end_to_end_keys.py | 9 ++++++++- tests/handlers/test_e2e_keys.py | 20 +++++++++++++++----- 5 files changed, 36 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10485.bugfix (limited to 'tests') diff --git a/changelog.d/10485.bugfix b/changelog.d/10485.bugfix new file mode 100644 index 0000000000..9b44006dc0 --- /dev/null +++ b/changelog.d/10485.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would not inform clients that a device had exhausted its one-time-key pool, potentially causing problems decrypting events. 
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8363c2bb0f..8c7ad2a407 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -127,6 +127,14 @@ class ToDeviceEventTypes: RoomKeyRequest = "m.room_key_request" +class DeviceKeyAlgorithms: + """Spec'd algorithms for the generation of per-device keys""" + + ED25519 = "ed25519" + CURVE25519 = "curve25519" + SIGNED_CURVE25519 = "signed_curve25519" + + class EduTypes: Presence = "m.presence" diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 150a4f291e..f30bfcc93c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1093,6 +1093,10 @@ class SyncHandler: one_time_key_counts: JsonDict = {} unused_fallback_key_types: List[str] = [] if device_id: + # TODO: We should have a way to let clients differentiate between the states of: + # * no change in OTK count since the provided since token + # * the server has zero OTKs left for this device + # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 78ae68ec68..1edc96042b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection +from synapse.api.constants import DeviceKeyAlgorithms from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, make_in_list_sql_clause @@ -381,9 +382,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): " GROUP BY algorithm" ) txn.execute(sql, (user_id, device_id)) - result = {} + + # Initially set the key count to 0. This ensures that the client will always + # receive *some count*, even if it's 0. + result = {DeviceKeyAlgorithms.SIGNED_CURVE25519: 0} + + # Override entries with the count of any keys we pulled from the database for algorithm, key_count in txn: result[algorithm] = key_count + return result return await self.db_pool.runInteraction( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index e0a24824cc..39e7b1ab25 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -47,12 +47,16 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): "alg2:k3": {"key": "key3"}, } + # Note that "signed_curve25519" is always returned in key count responses. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. 
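+        # (count_e2e_one_time_keys now seeds its result dict with a zero
+        # count for signed_curve25519 before adding the real per-algorithm
+        # counts, so the zero entry always appears.)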
res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} + ) # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" @@ -61,7 +65,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} + ) def test_change_one_time_keys(self): """attempts to change one-time-keys should be rejected""" @@ -79,7 +85,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} + ) # Error when changing string key self.get_failure( @@ -89,7 +97,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): SynapseError, ) - # Error when replacing dict key with strin + # Error when replacing dict key with string self.get_failure( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} @@ -131,7 +139,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): local_user, device_id, {"one_time_keys": keys} ) ) - self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} + ) res2 = self.get_success( self.handler.claim_one_time_keys( -- cgit 1.5.1
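
For context: with the third-party rules port above, new-style modules register
their own hooks. A minimal sketch of a module written against the new generic
interface (the module name and its do-nothing rule are illustrative, not taken
from these patches):

    from typing import Optional, Tuple

    class ExampleRulesModule:
        def __init__(self, config: dict, api):
            # `api` is a synapse.module_api.ModuleApi instance; the property
            # added in this series exposes the registration method.
            api.register_third_party_rules_callbacks(
                check_event_allowed=self.check_event_allowed,
            )

        async def check_event_allowed(
            self, event, state_events
        ) -> Tuple[bool, Optional[dict]]:
            # Allow every event and do not replace its content.
            return True, None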