From ef345c5a7b544aafa9c37bc2c4f626dfcef529f9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 10 Jun 2020 16:21:16 +0100 Subject: Add a new unread_counter to sync responses --- synapse/push/push_tools.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'synapse/push') diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 5dae4648c0..9f264ca4a4 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,7 +39,10 @@ def get_badge_count(store, user_id): ) # return one badge count per conversation, as count per # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + # We're populating this badge using the unread_count (instead of the + # notify_count) as this badge is the number of missed messages, not the + # number of missed notifications. + badge += 1 if notifs["unread_count"] else 0 return badge -- cgit 1.5.1 From ea8f6e611bdc4c2ee3f6fea76893650ba8f0facd Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 11 Jun 2020 15:30:42 +0100 Subject: Actually act on mark_unread --- synapse/push/bulk_push_rule_evaluator.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) (limited to 'synapse/push') diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index e75d964ac8..f7c3db5828 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -191,9 +191,13 @@ class BulkPushRuleEvaluator(object): ) if matches: actions = [x for x in rule["actions"] if x != "dont_notify"] - if actions and "notify" in actions: - # Push rules say we should notify the user of this event - actions_by_user[uid] = actions + if actions: + if ( + "notify" in actions + or "org.matrix.msc2625.mark_unread" in actions + ): + # Push rules say we should act on this event. + actions_by_user[uid] = actions break # Mark in the DB staging area the push actions for users who should be -- cgit 1.5.1 From e47e5a2dcd2e7210c3830c3f0b8420a8b0988133 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 12 Jun 2020 15:11:01 +0100 Subject: Incorporate review bits --- changelog.d/7673.feature | 2 +- synapse/push/bulk_push_rule_evaluator.py | 13 +++++------ .../storage/data_stores/main/event_push_actions.py | 27 +++++++++++----------- 3 files changed, 20 insertions(+), 22 deletions(-) (limited to 'synapse/push') diff --git a/changelog.d/7673.feature b/changelog.d/7673.feature index 74e2059ade..ecc3ffd8d5 100644 --- a/changelog.d/7673.feature +++ b/changelog.d/7673.feature @@ -1 +1 @@ -Add a per-room counter for unread messages in responses to `/sync` requests. +Add a per-room counter for unread messages in responses to `/sync` requests. Implements [MSC2625](https://github.com/matrix-org/matrix-doc/pull/2625). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f7c3db5828..3244d39c37 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -191,13 +191,12 @@ class BulkPushRuleEvaluator(object): ) if matches: actions = [x for x in rule["actions"] if x != "dont_notify"] - if actions: - if ( - "notify" in actions - or "org.matrix.msc2625.mark_unread" in actions - ): - # Push rules say we should act on this event. - actions_by_user[uid] = actions + if ( + "notify" in actions + or "org.matrix.msc2625.mark_unread" in actions + ): + # Push rules say we should act on this event. + actions_by_user[uid] = actions break # Mark in the DB staging area the push actions for users who should be diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 688aef4d2f..4409e87913 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Tuple import attr from six import iteritems @@ -857,11 +858,11 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.%s, 0) + upd.%s, + coalesce(old.%s, 0) + upd.cnt, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as %s, + SELECT user_id, room_id, count(*) as cnt, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? @@ -874,31 +875,29 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # First get the count of unread messages. txn.execute( - sql % ("unread_count", "unread_count", "unread_count", ""), + sql % ("unread_count", ""), (old_rotate_stream_ordering, rotate_to_stream_ordering), ) - unread_rows = txn.fetchall() - - # Then get the count of notifications. - txn.execute( - sql % ("notif_count", "notif_count", "notif_count", "AND notif = 1"), - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - notif_rows = txn.fetchall() # We need to merge both lists into a single object because we might not have the # same amount of rows in each of them. In this case we use a dict indexed on the # user ID and room ID to make it easier to populate. - summaries = {} - for row in unread_rows: + summaries = {} # type: Dict[Tuple[str, str], EventPushSummary] + for row in txn: summaries[(row[0], row[1])] = EventPushSummary( user_id=row[0], room_id=row[1], unread_count=row[2], stream_ordering=row[3], old_user_id=row[4], notif_count=0, ) + # Then get the count of notifications. + txn.execute( + sql % ("notif_count", "AND notif = 1"), + (old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + # notif_rows is populated based on a subset of the query used to populate # unread_rows, so we can be sure that there will be no KeyError here. - for row in notif_rows: + for row in txn: summaries[(row[0], row[1])].notif_count = row[2] logger.info("Rotating notifications, handling %d rows", len(summaries)) -- cgit 1.5.1 From bd6dc17221741d4ceae05ae769a70696ae939336 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Jun 2020 07:03:36 -0400 Subject: Replace iteritems/itervalues/iterkeys with native versions. (#7692) --- changelog.d/7692.misc | 1 + synapse/api/auth.py | 4 +-- synapse/api/errors.py | 3 +- synapse/app/homeserver.py | 6 ++-- synapse/events/__init__.py | 4 +-- synapse/events/snapshot.py | 4 +-- synapse/federation/federation_base.py | 4 +-- synapse/federation/federation_server.py | 7 ++-- synapse/federation/send_queue.py | 8 ++--- synapse/federation/sender/__init__.py | 4 +-- synapse/handlers/appservice.py | 4 +-- synapse/handlers/device.py | 14 ++++---- synapse/handlers/e2e_keys.py | 14 ++++---- synapse/handlers/e2e_room_keys.py | 6 ++-- synapse/handlers/federation.py | 22 ++++++------- synapse/handlers/groups_local.py | 4 +-- synapse/handlers/message.py | 8 ++--- synapse/handlers/pagination.py | 4 +-- synapse/handlers/presence.py | 14 ++++---- synapse/handlers/room.py | 6 ++-- synapse/handlers/room_list.py | 4 +-- synapse/handlers/sync.py | 34 +++++++++---------- synapse/handlers/user_directory.py | 6 ++-- synapse/metrics/__init__.py | 6 ++-- synapse/push/bulk_push_rule_evaluator.py | 14 ++++---- synapse/rest/media/v1/media_repository.py | 4 +-- synapse/server_notices/consent_server_notices.py | 4 +-- .../resource_limits_server_notices.py | 4 +-- synapse/state/__init__.py | 22 ++++++------- synapse/state/v1.py | 26 +++++++-------- synapse/state/v2.py | 6 ++-- synapse/storage/data_stores/main/client_ips.py | 6 ++-- synapse/storage/data_stores/main/devices.py | 8 ++--- .../storage/data_stores/main/end_to_end_keys.py | 6 ++-- .../storage/data_stores/main/event_push_actions.py | 4 +-- synapse/storage/data_stores/main/events.py | 22 ++++++------- synapse/storage/data_stores/main/registration.py | 4 +-- synapse/storage/data_stores/main/roommember.py | 10 +++--- synapse/storage/data_stores/state/bg_updates.py | 6 ++-- synapse/storage/data_stores/state/store.py | 17 +++++----- synapse/storage/database.py | 13 ++++---- synapse/storage/persist_events.py | 7 ++-- synapse/storage/state.py | 38 ++++++++++------------ synapse/util/caches/descriptors.py | 4 +-- synapse/util/caches/expiringcache.py | 6 ++-- synapse/util/caches/treecache.py | 4 +-- synapse/visibility.py | 21 +++++------- 47 files changed, 184 insertions(+), 263 deletions(-) create mode 100644 changelog.d/7692.misc (limited to 'synapse/push') diff --git a/changelog.d/7692.misc b/changelog.d/7692.misc new file mode 100644 index 0000000000..ef6cbe0005 --- /dev/null +++ b/changelog.d/7692.misc @@ -0,0 +1 @@ +Replace uses of `six.iterkeys`/`iteritems`/`itervalues` with `keys()`/`items()`/`values()`. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 06ade25674..06ba6604f3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -16,8 +16,6 @@ import logging from typing import Optional -from six import itervalues - import pymacaroons from netaddr import IPAddress @@ -90,7 +88,7 @@ class Auth(object): event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] event_auth.check( diff --git a/synapse/api/errors.py b/synapse/api/errors.py index d54dfb385d..a07a54580d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -19,7 +19,6 @@ import logging from typing import Dict, List -from six import iteritems from six.moves import http_client from canonicaljson import json @@ -497,7 +496,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs): A dict representing the error response JSON. """ err = {"error": msg, "errcode": code} - for key, value in iteritems(kwargs): + for key, value in kwargs.items(): err[key] = value return err diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8454d74858..93bc45208e 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -24,8 +24,6 @@ import os import resource import sys -from six import iteritems - from prometheus_client import Gauge from twisted.application import service @@ -525,7 +523,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): stats["total_nonbridged_users"] = total_nonbridged_users daily_user_type_results = yield hs.get_datastore().count_daily_user_type() - for name, count in iteritems(daily_user_type_results): + for name, count in daily_user_type_results.items(): stats["daily_user_type_" + name] = count room_count = yield hs.get_datastore().get_room_count() @@ -537,7 +535,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() r30_results = yield hs.get_datastore().count_r30_users() - for name, count in iteritems(r30_results): + for name, count in r30_results.items(): stats["r30_users_" + name] = count daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 533ba327f5..cc5deca75b 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -20,8 +20,6 @@ import os from distutils.util import strtobool from typing import Dict, Optional, Type -import six - from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions @@ -290,7 +288,7 @@ class EventBase(metaclass=abc.ABCMeta): return list(self._dict.items()) def keys(self): - return six.iterkeys(self._dict) + return self._dict.keys() def prev_event_ids(self): """Returns the list of prev event IDs. The order matches the order diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 7c5f620d09..f94cdcbaba 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -14,8 +14,6 @@ # limitations under the License. from typing import Optional, Union -from six import iteritems - import attr from frozendict import frozendict @@ -341,7 +339,7 @@ def _encode_state_dict(state_dict): if state_dict is None: return None - return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] + return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()] def _decode_state_dict(input): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c0012c6872..b2ab5bd6a4 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -93,8 +93,8 @@ class FederationBase(object): # *actual* redacted copy to be on the safe side.) redacted_event = prune_event(pdu) if set(redacted_event.keys()) == set(pdu.keys()) and set( - six.iterkeys(redacted_event.content) - ) == set(six.iterkeys(pdu.content)): + redacted_event.content.keys() + ) == set(pdu.content.keys()): logger.info( "Event %s seems to have been redacted; using our redacted " "copy", diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 32a8a2ee46..6920c23723 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -18,7 +18,6 @@ import logging from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union import six -from six import iteritems from canonicaljson import json from prometheus_client import Counter @@ -534,9 +533,9 @@ class FederationServer(FederationBase): ",".join( ( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) + for user_id, user_keys in json_result.items() + for device_id, device_keys in user_keys.items() + for key_id, _ in device_keys.items() ) ), ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 52f4f54215..6bbd762681 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -33,8 +33,6 @@ import logging from collections import namedtuple from typing import Dict, List, Tuple, Type -from six import iteritems - from sortedcontainers import SortedDict from twisted.internet import defer @@ -327,7 +325,7 @@ class FederationRemoteSendQueue(object): # stream position. keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in iteritems(keyed_edus): + for ((destination, edu_key), pos) in keyed_edus.items(): rows.append( ( pos, @@ -530,10 +528,10 @@ def process_rows_for_federation(transaction_queue, rows): states=[state], destinations=destinations ) - for destination, edu_map in iteritems(buff.keyed_edus): + for destination, edu_map in buff.keyed_edus.items(): for key, edu in edu_map.items(): transaction_queue.send_edu(edu, key) - for destination, edu_list in iteritems(buff.edus): + for destination, edu_list in buff.edus.items(): for edu in edu_list: transaction_queue.send_edu(edu, None) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index d473576902..5b8faea4e7 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -16,8 +16,6 @@ import logging from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple -from six import itervalues - from prometheus_client import Counter from twisted.internet import defer @@ -218,7 +216,7 @@ class FederationSender(object): defer.gatherResults( [ run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) + for evs in events_by_room.values() ], consumeErrors=True, ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index fe62f78e67..ac1b64caff 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -15,8 +15,6 @@ import logging -from six import itervalues - from prometheus_client import Counter from twisted.internet import defer @@ -125,7 +123,7 @@ class ApplicationServicesHandler(object): defer.gatherResults( [ run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) + for evs in events_by_room.values() ], consumeErrors=True, ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 230d170258..83f8fa1180 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -17,8 +17,6 @@ import logging from typing import Any, Dict, Optional -from six import iteritems, itervalues - from twisted.internet import defer from synapse.api import errors @@ -159,7 +157,7 @@ class DeviceWorkerHandler(BaseHandler): # The user may have left the room # TODO: Check if they actually did or if we were just invited. if room_id not in room_ids: - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -182,7 +180,7 @@ class DeviceWorkerHandler(BaseHandler): log_kv( {"event": "encountered empty previous state", "room_id": room_id} ) - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -198,10 +196,10 @@ class DeviceWorkerHandler(BaseHandler): # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. - for state_dict in itervalues(prev_state_ids): + for state_dict in prev_state_ids.values(): member_event = state_dict.get((EventTypes.Member, user_id), None) if not member_event or member_event != current_member_id: - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -211,14 +209,14 @@ class DeviceWorkerHandler(BaseHandler): # If there has been any change in membership, include them in the # possibly changed list. We'll check if they are joined below, # and we're not toooo worried about spuriously adding users. - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue # check if this member has changed since any of the extremities # at the stream_ordering, and add them to the list if so. - for state_dict in itervalues(prev_state_ids): + for state_dict in prev_state_ids.values(): prev_event_id = state_dict.get(key, None) if not prev_event_id or prev_event_id != event_id: if state_key != user_id: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 774a252619..a7e60cbc26 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -17,8 +17,6 @@ import logging -from six import iteritems - import attr from canonicaljson import encode_canonical_json, json from signedjson.key import decode_verify_key_bytes @@ -135,7 +133,7 @@ class E2eKeysHandler(object): remote_queries_not_in_cache = {} if remote_queries: query_list = [] - for user_id, device_ids in iteritems(remote_queries): + for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) else: @@ -145,9 +143,9 @@ class E2eKeysHandler(object): user_ids_not_in_cache, remote_results, ) = yield self.store.get_user_devices_from_cache(query_list) - for user_id, devices in iteritems(remote_results): + for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) - for device_id, device in iteritems(devices): + for device_id, device in devices.items(): keys = device.get("keys", None) device_display_name = device.get("device_display_name", None) if keys: @@ -446,9 +444,9 @@ class E2eKeysHandler(object): ",".join( ( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) + for user_id, user_keys in json_result.items() + for device_id, device_keys in user_keys.items() + for key_id, _ in device_keys.items() ) ), ) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 9abaf13b8f..2efea801bc 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -16,8 +16,6 @@ import logging -from six import iteritems - from twisted.internet import defer from synapse.api.errors import ( @@ -205,8 +203,8 @@ class E2eRoomKeysHandler(object): ) to_insert = [] # batch the inserts together changed = False # if anything has changed, we need to update the etag - for room_id, room in iteritems(room_keys["rooms"]): - for session_id, room_key in iteritems(room["sessions"]): + for room_id, room in room_keys["rooms"].items(): + for session_id, room_key in room["sessions"].items(): if not isinstance(room_key["is_verified"], bool): msg = ( "is_verified must be a boolean in keys for session %s in" diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b30f41dc4b..d6038d9995 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,8 +21,6 @@ import itertools import logging from typing import Dict, Iterable, List, Optional, Sequence, Tuple -import six -from six import iteritems, itervalues from six.moves import http_client, zip import attr @@ -398,7 +396,7 @@ class FederationHandler(BaseHandler): ) event_map.update(evs) - state = [event_map[e] for e in six.itervalues(state_map)] + state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -1009,7 +1007,7 @@ class FederationHandler(BaseHandler): """ joined_users = [ (state_key, int(event.depth)) - for (e_type, state_key), event in iteritems(state) + for (e_type, state_key), event in state.items() if e_type == EventTypes.Member and event.membership == Membership.JOIN ] @@ -1099,16 +1097,16 @@ class FederationHandler(BaseHandler): states = dict(zip(event_ids, [s.state for s in states])) state_map = await self.store.get_events( - [e_id for ids in itervalues(states) for e_id in itervalues(ids)], + [e_id for ids in states.values() for e_id in ids.values()], get_prev_content=False, ) states = { key: { k: state_map[e_id] - for k, e_id in iteritems(state_dict) + for k, e_id in state_dict.items() if e_id in state_map } - for key, state_dict in iteritems(states) + for key, state_dict in states.items() } for e_id, _ in sorted_extremeties_tuple: @@ -1733,7 +1731,7 @@ class FederationHandler(BaseHandler): state_groups = await self.state_store.get_state_groups(room_id, [event_id]) if state_groups: - _, state = list(iteritems(state_groups)).pop() + _, state = list(state_groups.items()).pop() results = {(e.type, e.state_key): e for e in state} if event.is_state(): @@ -2096,7 +2094,7 @@ class FederationHandler(BaseHandler): room_version, state_sets, event ) current_state_ids = { - k: e.event_id for k, e in iteritems(current_state_ids) + k: e.event_id for k, e in current_state_ids.items() } else: current_state_ids = await self.state_handler.get_current_state_ids( @@ -2112,7 +2110,7 @@ class FederationHandler(BaseHandler): # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ - e for k, e in iteritems(current_state_ids) if k in auth_types + e for k, e in current_state_ids.items() if k in auth_types ] current_auth_events = await self.store.get_events(current_state_ids) @@ -2428,7 +2426,7 @@ class FederationHandler(BaseHandler): else: event_key = None state_updates = { - k: a.event_id for k, a in iteritems(auth_events) if k != event_key + k: a.event_id for k, a in auth_events.items() if k != event_key } current_state_ids = await context.get_current_state_ids() @@ -2439,7 +2437,7 @@ class FederationHandler(BaseHandler): prev_state_ids = await context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) # create a new state group as a delta from the existing one. prev_group = context.state_group diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ebe8d25bd8..7cb106e365 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -16,8 +16,6 @@ import logging -from six import iteritems - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.types import get_domain_from_id @@ -227,7 +225,7 @@ class GroupsLocalWorkerHandler(object): results = {} failed_results = [] - for destination, dest_user_ids in iteritems(destinations): + for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( destination, list(dest_user_ids) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 649ca1f08a..354da9a3b5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -17,7 +17,7 @@ import logging from typing import Optional, Tuple -from six import iteritems, itervalues, string_types +from six import string_types from canonicaljson import encode_canonical_json, json @@ -246,7 +246,7 @@ class MessageHandler(object): "avatar_url": profile.avatar_url, "display_name": profile.display_name, } - for user_id, profile in iteritems(users_with_profile) + for user_id, profile in users_with_profile.items() } def maybe_schedule_expiry(self, event): @@ -988,7 +988,7 @@ class EventCreationHandler(object): state_to_include_ids = [ e_id - for k, e_id in iteritems(current_state_ids) + for k, e_id in current_state_ids.items() if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -1002,7 +1002,7 @@ class EventCreationHandler(object): "content": e.content, "sender": e.sender, } - for e in itervalues(state_to_include) + for e in state_to_include.values() ] invitee = UserID.from_string(event.state_key) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index d7442c62a7..7fbc229502 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -from six import iteritems - from twisted.internet import defer from twisted.python.failure import Failure @@ -145,7 +143,7 @@ class PaginationHandler(object): logger.debug("[purge] Rooms to purge: %s", rooms) - for room_id, retention_policy in iteritems(rooms): + for room_id, retention_policy in rooms.items(): logger.info("[purge] Attempting to purge messages in room %s", room_id) if room_id in self._purges_in_progress_by_room: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3594f3b00f..2e8914be14 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -27,8 +27,6 @@ import logging from contextlib import contextmanager from typing import Dict, Iterable, List, Set -from six import iteritems, itervalues - from prometheus_client import Counter from typing_extensions import ContextManager @@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC): for user_id in user_ids } - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) for user_id in missing @@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) - for prev_state in itervalues(prev_states) + for prev_state in prev_states.values() ] ) self.external_process_last_updated_ms.pop(process_id, None) @@ -1087,7 +1085,7 @@ class PresenceEventSource(object): return (list(updates.values()), max_token) else: return ( - [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE], + [s for s in updates.values() if s.state != PresenceState.OFFLINE], max_token, ) @@ -1323,11 +1321,11 @@ def get_interested_remotes(store, states, state_handler): # hosts in those rooms. room_ids_to_states, users_to_states = yield get_interested_parties(store, states) - for room_id, states in iteritems(room_ids_to_states): + for room_id, states in room_ids_to_states.items(): hosts = yield state_handler.get_current_hosts_in_room(room_id) hosts_and_states.append((hosts, states)) - for user_id, states in iteritems(users_to_states): + for user_id, states in users_to_states.items(): host = get_domain_from_id(user_id) hosts_and_states.append(([host], states)) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 46c2739143..f7401373ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -24,7 +24,7 @@ import string from collections import OrderedDict from typing import Tuple -from six import iteritems, string_types +from six import string_types from synapse.api.constants import ( EventTypes, @@ -377,7 +377,7 @@ class RoomCreationHandler(BaseHandler): # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) - for k, old_event_id in iteritems(old_room_state_ids): + for k, old_event_id in old_room_state_ids.items(): old_event = old_room_state_events.get(old_event_id) if old_event: initial_state[k] = old_event.content @@ -430,7 +430,7 @@ class RoomCreationHandler(BaseHandler): old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) - for k, old_event in iteritems(old_room_member_state_events): + for k, old_event in old_room_member_state_events.items(): # Only transfer ban events if ( "membership" in old_event.content diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 4cbc02b0d0..5e05be6181 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Any, Dict, Optional -from six import iteritems - import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -271,7 +269,7 @@ class RoomListHandler(BaseHandler): event_map = yield self.store.get_events( [ event_id - for key, event_id in iteritems(current_state_ids) + for key, event_id in current_state_ids.items() if key[0] in ( EventTypes.Create, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6bdb24baff..4c7524493e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple -from six import iteritems, itervalues - import attr from prometheus_client import Counter @@ -390,7 +388,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} + event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -408,7 +406,7 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} + event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) return now_token, ephemeral_by_room @@ -454,7 +452,7 @@ class SyncHandler(object): current_state_ids_map = await self.state.get_current_state_ids( room_id ) - current_state_ids = frozenset(itervalues(current_state_ids_map)) + current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( self.storage, @@ -509,7 +507,7 @@ class SyncHandler(object): current_state_ids_map = await self.state.get_current_state_ids( room_id ) - current_state_ids = frozenset(itervalues(current_state_ids_map)) + current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( self.storage, @@ -909,7 +907,7 @@ class SyncHandler(object): logger.debug("filtering state from %r...", state_ids) state_ids = { t: event_id - for t, event_id in iteritems(state_ids) + for t, event_id in state_ids.items() if cache.get(t[1]) != event_id } logger.debug("...to %r", state_ids) @@ -1430,7 +1428,7 @@ class SyncHandler(object): if since_token: for joined_sync in sync_result_builder.joined: it = itertools.chain( - joined_sync.timeline.events, itervalues(joined_sync.state) + joined_sync.timeline.events, joined_sync.state.values() ) for event in it: if event.type == EventTypes.Member: @@ -1505,7 +1503,7 @@ class SyncHandler(object): newly_left_rooms = [] room_entries = [] invited = [] - for room_id, events in iteritems(mem_change_events_by_room_id): + for room_id, events in mem_change_events_by_room_id.items(): logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1993,17 +1991,17 @@ def _calculate_state( event_id_to_key = { e: key for key, e in itertools.chain( - iteritems(timeline_contains), - iteritems(previous), - iteritems(timeline_start), - iteritems(current), + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(itervalues(current)) - ts_ids = set(itervalues(timeline_start)) - p_ids = set(itervalues(previous)) - tc_ids = set(itervalues(timeline_contains)) + c_ids = set(current.values()) + ts_ids = set(timeline_start.values()) + p_ids = set(previous.values()) + tc_ids = set(timeline_contains.values()) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -2017,7 +2015,7 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member + e for t, e in timeline_start.items() if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 12423b909a..521b6d620d 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -15,8 +15,6 @@ import logging -from six import iteritems, iterkeys - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler @@ -289,7 +287,7 @@ class UserDirectoryHandler(StateDeltasHandler): users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. - for user_id in iterkeys(users_with_profile): + for user_id in users_with_profile.keys(): await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. @@ -298,7 +296,7 @@ class UserDirectoryHandler(StateDeltasHandler): # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. - for user_id, profile in iteritems(users_with_profile): + for user_id, profile in users_with_profile.items(): await self._handle_new_user(room_id, user_id, profile) async def _handle_new_user(self, room_id, user_id, profile): diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 9cf31f96b3..087a49d65d 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -22,8 +22,6 @@ import threading import time from typing import Callable, Dict, Iterable, Optional, Tuple, Union -import six - import attr from prometheus_client import Counter, Gauge, Histogram from prometheus_client.core import ( @@ -83,7 +81,7 @@ class LaterGauge(object): return if isinstance(calls, dict): - for k, v in six.iteritems(calls): + for k, v in calls.items(): g.add_metric(k, v) else: g.add_metric([], calls) @@ -194,7 +192,7 @@ class InFlightGauge(object): gauge = GaugeMetricFamily( "_".join([self.name, name]), "", labels=self.labels ) - for key, metrics in six.iteritems(metrics_by_key): + for key, metrics in metrics_by_key.items(): gauge.add_metric(key, getattr(metrics, name)) yield gauge diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index e75d964ac8..43ffe6faf0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple -from six import iteritems, itervalues - from prometheus_client import Counter from twisted.internet import defer @@ -130,7 +128,7 @@ class BulkPushRuleEvaluator(object): event, prev_state_ids, for_verification=False ) auth_events = yield self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} + auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -162,7 +160,7 @@ class BulkPushRuleEvaluator(object): condition_cache = {} - for uid, rules in iteritems(rules_by_user): + for uid, rules in rules_by_user.items(): if event.sender == uid: continue @@ -395,7 +393,7 @@ class RulesForRoom(object): # If the event is a join event then it will be in current state evnts # map but not in the DB, so we have to explicitly insert it. if event.type == EventTypes.Member: - for event_id in itervalues(member_event_ids): + for event_id in member_event_ids.values(): if event_id == event.event_id: members[event_id] = (event.state_key, event.membership) @@ -404,7 +402,7 @@ class RulesForRoom(object): interested_in_user_ids = { user_id - for user_id, membership in itervalues(members) + for user_id, membership in members.values() if membership == Membership.JOIN } @@ -415,7 +413,7 @@ class RulesForRoom(object): ) user_ids = { - uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher + uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher } logger.debug("With pushers: %r", user_ids) @@ -436,7 +434,7 @@ class RulesForRoom(object): ) ret_rules_by_user.update( - item for item in iteritems(rules_by_user) if item[0] is not None + item for item in rules_by_user.items() if item[0] is not None ) self.update_cache(sequence, members, ret_rules_by_user, state_group) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index fd10d42f2f..4ee8c60257 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -20,8 +20,6 @@ import os import shutil from typing import Dict, Tuple -from six import iteritems - import twisted.internet.error import twisted.web.http from twisted.web.resource import Resource @@ -606,7 +604,7 @@ class MediaRepository(object): thumbnails[(t_width, t_height, r_type)] = r_method # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in iteritems(thumbnails): + for (t_width, t_height, t_type), t_method in thumbnails.items(): # Generate the thumbnail if t_method == "crop": t_byte_source = await defer_to_thread( diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 3bf330da49..e7e8b8e688 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from six import iteritems, string_types +from six import string_types from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder @@ -121,7 +121,7 @@ def copy_with_str_subst(x, substitutions): if isinstance(x, string_types): return x % substitutions if isinstance(x, dict): - return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)} + return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()} if isinstance(x, (list, tuple)): return [copy_with_str_subst(y) for y in x] diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 73f2cedb5c..4404ceff93 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from six import iteritems - from synapse.api.constants import ( EventTypes, LimitBlockingTypes, @@ -214,7 +212,7 @@ class ResourceLimitsServerNotices(object): referenced_events = list(pinned_state_event.content.get("pinned", [])) events = await self._store.get_events(referenced_events) - for event_id, event in iteritems(events): + for event_id, event in events.items(): if event.type != EventTypes.Message: continue if event.content.get("msgtype") == ServerNoticeMsgType: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 2fa529fcd0..50fd843f66 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -18,8 +18,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Optional, Set -from six import iteritems, itervalues - import attr from frozendict import frozendict from prometheus_client import Histogram @@ -144,7 +142,7 @@ class StateHandler(object): list(state.values()), get_prev_content=False ) state = { - key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map + key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } return state @@ -423,7 +421,7 @@ class StateHandler(object): state_res_store=StateResolutionStore(self.store), ) - new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} + new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()} return new_state @@ -505,8 +503,8 @@ class StateResolutionHandler(object): # resolve_events_with_store do it? new_state = {} conflicted_state = False - for st in itervalues(state_groups_ids): - for key, e_id in iteritems(st): + for st in state_groups_ids.values(): + for key, e_id in st.items(): if key in new_state: conflicted_state = True break @@ -520,7 +518,7 @@ class StateResolutionHandler(object): new_state = yield resolve_events_with_store( room_id, room_version, - list(itervalues(state_groups_ids)), + list(state_groups_ids.values()), event_map=event_map, state_res_store=state_res_store, ) @@ -561,12 +559,12 @@ def _make_state_cache_entry(new_state, state_groups_ids): # not get persisted. # first look for exact matches - new_state_event_ids = set(itervalues(new_state)) - for sg, state in iteritems(state_groups_ids): + new_state_event_ids = set(new_state.values()) + for sg, state in state_groups_ids.items(): if len(new_state_event_ids) != len(state): continue - old_state_event_ids = set(itervalues(state)) + old_state_event_ids = set(state.values()) if new_state_event_ids == old_state_event_ids: # got an exact match. return _StateCacheEntry(state=new_state, state_group=sg) @@ -579,8 +577,8 @@ def _make_state_cache_entry(new_state, state_groups_ids): prev_group = None delta_ids = None - for old_group, old_state in iteritems(state_groups_ids): - n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} + for old_group, old_state in state_groups_ids.items(): + n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 9bf98d06f2..7b531a8337 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -17,8 +17,6 @@ import hashlib import logging from typing import Callable, Dict, List, Optional -from six import iteritems, iterkeys, itervalues - from twisted.internet import defer from synapse import event_auth @@ -70,11 +68,11 @@ def resolve_events_with_store( unconflicted_state, conflicted_state = _seperate(state_sets) needed_events = { - event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids + event_id for event_ids in conflicted_state.values() for event_id in event_ids } needed_event_count = len(needed_events) if event_map is not None: - needed_events -= set(iterkeys(event_map)) + needed_events -= set(event_map.keys()) logger.info( "Asking for %d/%d conflicted events", len(needed_events), needed_event_count @@ -102,11 +100,11 @@ def resolve_events_with_store( unconflicted_state, conflicted_state, state_map ) - new_needed_events = set(itervalues(auth_events)) + new_needed_events = set(auth_events.values()) new_needed_event_count = len(new_needed_events) new_needed_events -= needed_events if event_map is not None: - new_needed_events -= set(iterkeys(event_map)) + new_needed_events -= set(event_map.keys()) logger.info( "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count @@ -152,7 +150,7 @@ def _seperate(state_sets): conflicted_state = {} for state_set in state_set_iterator: - for key, value in iteritems(state_set): + for key, value in state_set.items(): # Check if there is an unconflicted entry for the state key. unconflicted_value = unconflicted_state.get(key) if unconflicted_value is None: @@ -178,7 +176,7 @@ def _seperate(state_sets): def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): auth_events = {} - for event_ids in itervalues(conflicted_state): + for event_ids in conflicted_state.values(): for event_id in event_ids: if event_id in state_map: keys = event_auth.auth_types_for_event(state_map[event_id]) @@ -194,7 +192,7 @@ def _resolve_with_state( unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map ): conflicted_state = {} - for key, event_ids in iteritems(conflicted_state_ids): + for key, event_ids in conflicted_state_ids.items(): events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] if len(events) > 1: conflicted_state[key] = events @@ -203,7 +201,7 @@ def _resolve_with_state( auth_events = { key: state_map[ev_id] - for key, ev_id in iteritems(auth_event_ids) + for key, ev_id in auth_event_ids.items() if ev_id in state_map } @@ -214,7 +212,7 @@ def _resolve_with_state( raise new_state = unconflicted_state_ids - for key, event in iteritems(resolved_state): + for key, event in resolved_state.items(): new_state[key] = event.event_id return new_state @@ -238,21 +236,21 @@ def _resolve_state_events(conflicted_state, auth_events): auth_events.update(resolved_state) - for key, events in iteritems(conflicted_state): + for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) - for key, events in iteritems(conflicted_state): + for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) resolved_state[key] = _resolve_auth_events(events, auth_events) auth_events.update(resolved_state) - for key, events in iteritems(conflicted_state): + for key, events in conflicted_state.items(): if key not in resolved_state: logger.debug("Resolving conflicted state %r:%r", key, events) resolved_state[key] = _resolve_normal_events(events, auth_events) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 18484e2fa6..e25bc5d264 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Dict, List, Optional -from six import iteritems, itervalues - from twisted.internet import defer import synapse.state @@ -87,7 +85,7 @@ def resolve_events_with_store( full_conflicted_set = set( itertools.chain( - itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff + itertools.chain.from_iterable(conflicted_state.values()), auth_diff ) ) @@ -572,7 +570,7 @@ def lexicographical_topological_sort(graph, key): # `(key(node), node)` so that sorting does the right thing zero_outdegree = [] - for node, edges in iteritems(graph): + for node, edges in graph.items(): if len(edges) == 0: zero_outdegree.append((key(node), node)) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 71f8d43a76..995d4764a9 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -15,8 +15,6 @@ import logging -from six import iteritems - from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -421,7 +419,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): ): self.database_engine.lock_table(txn, "user_ips") - for entry in iteritems(to_update): + for entry in to_update.items(): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: @@ -530,7 +528,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "user_agent": user_agent, "last_seen": last_seen, } - for (access_token, ip), (user_agent, last_seen) in iteritems(results) + for (access_token, ip), (user_agent, last_seen) in results.items() ] @wrap_as_background_process("prune_old_user_ips") diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index fb9f798e29..0ff0542453 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional, Set, Tuple -from six import iteritems - from canonicaljson import json from twisted.internet import defer @@ -208,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore): ) # add the updated cross-signing keys to the results list - for user_id, result in iteritems(cross_signing_keys_by_user): + for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec results.append(("org.matrix.signing_key_update", result)) @@ -269,7 +267,7 @@ class DeviceWorkerStore(SQLBaseStore): ) results = [] - for user_id, user_devices in iteritems(devices): + for user_id, user_devices in devices.items(): # The prev_id for the first row is always the last row before # `from_stream_id` prev_id = yield self._get_last_device_update_for_remote_user( @@ -493,7 +491,7 @@ class DeviceWorkerStore(SQLBaseStore): if devices: user_devices = devices[user_id] results = [] - for device_id, device in iteritems(user_devices): + for device_id, device in user_devices.items(): result = {"device_id": device_id} key_json = device.get("key_json", None) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 20698bfd16..1a0842d4b0 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -16,8 +16,6 @@ # limitations under the License. from typing import Dict, List -from six import iteritems - from canonicaljson import encode_canonical_json, json from twisted.enterprise.adbapi import Connection @@ -64,9 +62,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # Build the result structure, un-jsonify the results, and add the # "unsigned" section rv = {} - for user_id, device_keys in iteritems(results): + for user_id, device_keys in results.items(): rv[user_id] = {} - for device_id, device_info in iteritems(device_keys): + for device_id, device_info in device_keys.items(): r = db_to_json(device_info.pop("key_json")) r["unsigned"] = {} display_name = device_info["device_display_name"] diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 0321274de2..bc9f4f08ea 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -16,8 +16,6 @@ import logging -from six import iteritems - from canonicaljson import json from twisted.internet import defer @@ -455,7 +453,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): sql, ( _gen_entry(user_id, actions) - for user_id, actions in iteritems(user_id_actions) + for user_id, actions in user_id_actions.items() ), ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index a6572571b4..8a13101f1d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -21,7 +21,7 @@ from collections import OrderedDict, namedtuple from functools import wraps from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple -from six import integer_types, iteritems, text_type +from six import integer_types, text_type from six.moves import range import attr @@ -232,10 +232,10 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, new_state in iteritems(current_state_for_room): + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in iteritems(new_forward_extremeties): + for room_id, latest_event_ids in new_forward_extremeties.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -461,7 +461,7 @@ class PersistEventsStore: state_delta_by_room: Dict[str, DeltaState], stream_id: int, ): - for room_id, delta_state in iteritems(state_delta_by_room): + for room_id, delta_state in state_delta_by_room.items(): to_delete = delta_state.to_delete to_insert = delta_state.to_insert @@ -545,7 +545,7 @@ class PersistEventsStore: """, [ (room_id, key[0], key[1], ev_id, ev_id) - for key, ev_id in iteritems(to_insert) + for key, ev_id in to_insert.items() ], ) @@ -642,7 +642,7 @@ class PersistEventsStore: def _update_forward_extremities_txn( self, txn, new_forward_extremities, max_stream_order ): - for room_id, new_extrem in iteritems(new_forward_extremities): + for room_id, new_extrem in new_forward_extremities.items(): self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) @@ -655,7 +655,7 @@ class PersistEventsStore: table="event_forward_extremities", values=[ {"event_id": ev_id, "room_id": room_id} - for room_id, new_extrem in iteritems(new_forward_extremities) + for room_id, new_extrem in new_forward_extremities.items() for ev_id in new_extrem ], ) @@ -672,7 +672,7 @@ class PersistEventsStore: "event_id": event_id, "stream_ordering": max_stream_order, } - for room_id, new_extrem in iteritems(new_forward_extremities) + for room_id, new_extrem in new_forward_extremities.items() for event_id in new_extrem ], ) @@ -727,7 +727,7 @@ class PersistEventsStore: event.depth, depth_updates.get(event.room_id, event.depth) ) - for room_id, depth in iteritems(depth_updates): + for room_id, depth in depth_updates.items(): self._update_min_depth_for_room_txn(txn, room_id, depth) def _update_outliers_txn(self, txn, events_and_contexts): @@ -1497,11 +1497,11 @@ class PersistEventsStore: table="event_to_state_groups", values=[ {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) + for event_id, state_group_id in state_groups.items() ], ) - for event_id, state_group_id in iteritems(state_groups): + for event_id, state_group_id in state_groups.items(): txn.call_after( self.store._get_state_group_for_event.prefill, (event_id,), diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 9768981891..587d4b91c1 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -19,8 +19,6 @@ import logging import re from typing import Optional -from six import iterkeys - from twisted.internet import defer from twisted.internet.defer import Deferred @@ -753,7 +751,7 @@ class RegistrationWorkerStore(SQLBaseStore): last_send_attempt, validated_at FROM threepid_validation_session WHERE %s """ % ( - " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + " AND ".join("%s = ?" % k for k in keyvalues.keys()), ) if validated is not None: diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 137ebac833..44bab65eac 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -17,8 +17,6 @@ import logging from typing import Iterable, List, Set -from six import iteritems, itervalues - from canonicaljson import json from twisted.internet import defer @@ -544,7 +542,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): users_in_room = {} member_event_ids = [ e_id - for key, e_id in iteritems(current_state_ids) + for key, e_id in current_state_ids.items() if key[0] == EventTypes.Member ] @@ -561,7 +559,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): users_in_room = dict(prev_res) member_event_ids = [ e_id - for key, e_id in iteritems(context.delta_ids) + for key, e_id in context.delta_ids.items() if key[0] == EventTypes.Member ] for etype, state_key in context.delta_ids: @@ -1101,7 +1099,7 @@ class _JoinedHostsCache(object): if state_entry.state_group == self.state_group: pass elif state_entry.prev_group == self.state_group: - for (typ, state_key), event_id in iteritems(state_entry.delta_ids): + for (typ, state_key), event_id in state_entry.delta_ids.items(): if typ != EventTypes.Member: continue @@ -1131,7 +1129,7 @@ class _JoinedHostsCache(object): self.state_group = state_entry.state_group else: self.state_group = object() - self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) + self._len = sum(len(v) for v in self.hosts_to_joined_users.values()) return frozenset(self.hosts_to_joined_users) def __len__(self): diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py index ff000bc9ec..be1fe97d79 100644 --- a/synapse/storage/data_stores/state/bg_updates.py +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -15,8 +15,6 @@ import logging -from six import iteritems - from twisted.internet import defer from synapse.storage._base import SQLBaseStore @@ -280,7 +278,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): delta_state = { key: value - for key, value in iteritems(curr_state) + for key, value in curr_state.items() if prev_state.get(key, None) != value } @@ -316,7 +314,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(delta_state) + for key, state_id in delta_state.items() ], ) diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index f3ad1e4369..b720212e55 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -17,7 +17,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from six import iteritems from six.moves import range from twisted.internet import defer @@ -263,7 +262,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # And finally update the result dict, by filtering out any extra # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): + for group, group_state_dict in group_to_state_dict.items(): # We just replace any existing entries, as we will have loaded # everything we need from the database anyway. state[group] = state_filter.filter_state(group_state_dict) @@ -341,11 +340,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): else: non_member_types = non_member_filter.concrete_types() - for group, group_state_dict in iteritems(group_to_state_dict): + for group, group_state_dict in group_to_state_dict.items(): state_dict_members = {} state_dict_non_members = {} - for k, v in iteritems(group_state_dict): + for k, v in group_state_dict.items(): if k[0] == EventTypes.Member: state_dict_members[k] = v else: @@ -432,7 +431,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(delta_ids) + for key, state_id in delta_ids.items() ], ) else: @@ -447,7 +446,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(current_state_ids) + for key, state_id in current_state_ids.items() ], ) @@ -458,7 +457,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): current_member_state_ids = { s: ev - for (s, ev) in iteritems(current_state_ids) + for (s, ev) in current_state_ids.items() if s[0] == EventTypes.Member } txn.call_after( @@ -470,7 +469,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): current_non_member_state_ids = { s: ev - for (s, ev) in iteritems(current_state_ids) + for (s, ev) in current_state_ids.items() if s[0] != EventTypes.Member } txn.call_after( @@ -555,7 +554,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in iteritems(curr_state) + for key, state_id in curr_state.items() ], ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index b112ff3df2..645a70934c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,7 +29,6 @@ from typing import ( TypeVar, ) -from six import iteritems, iterkeys, itervalues from six.moves import intern, range from prometheus_client import Histogram @@ -259,7 +258,7 @@ class PerformanceCounters(object): def interval(self, interval_duration_secs, limit=3): counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): + for name, (count, cum_time) in self.current_counters.items(): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) counters.append( ( @@ -1053,7 +1052,7 @@ class Database(object): sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) txn.execute(sql, list(keyvalues.values())) else: txn.execute(sql) @@ -1191,7 +1190,7 @@ class Database(object): clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) clauses = [clause] - for key, value in iteritems(keyvalues): + for key, value in keyvalues.items(): clauses.append("%s = ?" % (key,)) values.append(value) @@ -1212,7 +1211,7 @@ class Database(object): @staticmethod def simple_update_txn(txn, table, keyvalues, updatevalues): if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) else: where = "" @@ -1351,7 +1350,7 @@ class Database(object): clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) clauses = [clause] - for key, value in iteritems(keyvalues): + for key, value in keyvalues.items(): clauses.append("%s = ?" % (key,)) values.append(value) @@ -1388,7 +1387,7 @@ class Database(object): txn.close() if cache: - min_val = min(itervalues(cache)) + min_val = min(cache.values()) else: min_val = max_value diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index f159400a87..92dfd709bc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -20,7 +20,6 @@ import logging from collections import deque, namedtuple from typing import Iterable, List, Optional, Set, Tuple -from six import iteritems from six.moves import range from prometheus_client import Counter, Histogram @@ -218,7 +217,7 @@ class EventsPersistenceStorage(object): partitioned.setdefault(event.room_id, []).append((event, ctx)) deferreds = [] - for room_id, evs_ctxs in iteritems(partitioned): + for room_id, evs_ctxs in partitioned.items(): d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled ) @@ -319,7 +318,7 @@ class EventsPersistenceStorage(object): (event, context) ) - for room_id, ev_ctx_rm in iteritems(events_by_room): + for room_id, ev_ctx_rm in events_by_room.items(): latest_event_ids = await self.main_store.get_latest_event_ids_in_room( room_id ) @@ -674,7 +673,7 @@ class EventsPersistenceStorage(object): to_insert = { key: ev_id - for key, ev_id in iteritems(current_state) + for key, ev_id in current_state.items() if ev_id != existing_state.get(key) } diff --git a/synapse/storage/state.py b/synapse/storage/state.py index c522c80922..dc568476f4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -16,8 +16,6 @@ import logging from typing import Iterable, List, TypeVar -from six import iteritems, itervalues - import attr from twisted.internet import defer @@ -51,7 +49,7 @@ class StateFilter(object): # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: - self.types = {k: v for k, v in iteritems(self.types) if v is not None} + self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod def all(): @@ -150,7 +148,7 @@ class StateFilter(object): has_non_member_wildcard = self.include_others or any( state_keys is None - for t, state_keys in iteritems(self.types) + for t, state_keys in self.types.items() if t != EventTypes.Member ) @@ -199,7 +197,7 @@ class StateFilter(object): # First we build up a lost of clauses for each type/state_key combo clauses = [] - for etype, state_keys in iteritems(self.types): + for etype, state_keys in self.types.items(): if state_keys is None: clauses.append("(type = ?)") where_args.append(etype) @@ -251,7 +249,7 @@ class StateFilter(object): return dict(state_dict) filtered_state = {} - for k, v in iteritems(state_dict): + for k, v in state_dict.items(): typ, state_key = k if typ in self.types: state_keys = self.types[typ] @@ -279,7 +277,7 @@ class StateFilter(object): """ return self.include_others or any( - state_keys is None for state_keys in itervalues(self.types) + state_keys is None for state_keys in self.types.values() ) def concrete_types(self): @@ -292,7 +290,7 @@ class StateFilter(object): """ return [ (t, s) - for t, state_keys in iteritems(self.types) + for t, state_keys in self.types.items() if state_keys is not None for s in state_keys ] @@ -324,7 +322,7 @@ class StateFilter(object): member_filter = StateFilter.none() non_member_filter = StateFilter( - types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + types={k: v for k, v in self.types.items() if k != EventTypes.Member}, include_others=self.include_others, ) @@ -366,7 +364,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) - groups = set(itervalues(event_to_groups)) + groups = set(event_to_groups.values()) group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -400,8 +398,8 @@ class StateGroupStorage(object): state_event_map = yield self.stores.main.get_events( [ ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) + for group_ids in group_to_ids.values() + for ev_id in group_ids.values() ], get_prev_content=False, ) @@ -409,10 +407,10 @@ class StateGroupStorage(object): return { group: [ state_event_map[v] - for v in itervalues(event_id_map) + for v in event_id_map.values() if v in state_event_map ] - for group, event_id_map in iteritems(group_to_ids) + for group, event_id_map in group_to_ids.items() } def _get_state_groups_from_groups( @@ -444,23 +442,23 @@ class StateGroupStorage(object): """ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) - groups = set(itervalues(event_to_groups)) + groups = set(event_to_groups.values()) group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) state_event_map = yield self.stores.main.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) + for k, v in group_to_state[group].items() if v in state_event_map } - for event_id, group in iteritems(event_to_groups) + for event_id, group in event_to_groups.items() } return {event: event_to_state[event] for event in event_ids} @@ -481,14 +479,14 @@ class StateGroupStorage(object): """ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) - groups = set(itervalues(event_to_groups)) + groups = set(event_to_groups.values()) group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) event_to_state = { event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) + for event_id, group in event_to_groups.items() } return {event: event_to_state[event] for event in event_ids} diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index cd48262420..64f35fc288 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -21,8 +21,6 @@ import threading from typing import Any, Tuple, Union, cast from weakref import WeakValueDictionary -from six import itervalues - from prometheus_client import Gauge from typing_extensions import Protocol @@ -281,7 +279,7 @@ class Cache(object): def invalidate_all(self): self.check_thread() self.cache.clear() - for entry in itervalues(self._pending_deferred_cache): + for entry in self._pending_deferred_cache.values(): entry.invalidate() self._pending_deferred_cache.clear() diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2726b67b6d..89a3420f92 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -16,8 +16,6 @@ import logging from collections import OrderedDict -from six import iteritems, itervalues - from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -150,7 +148,7 @@ class ExpiringCache(object): keys_to_delete = set() - for key, cache_entry in iteritems(self._cache): + for key, cache_entry in self._cache.items(): if now - cache_entry.time > self._expiry_ms: keys_to_delete.add(key) @@ -170,7 +168,7 @@ class ExpiringCache(object): def __len__(self): if self.iterable: - return sum(len(entry.value) for entry in itervalues(self._cache)) + return sum(len(entry.value) for entry in self._cache.values()) else: return len(self._cache) diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 2ea4e4e911..ecd9948e79 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -1,7 +1,5 @@ from typing import Dict -from six import itervalues - SENTINEL = object() @@ -81,7 +79,7 @@ def iterate_tree_cache_entry(d): can contain dicts. """ if isinstance(d, dict): - for value_d in itervalues(d): + for value_d in d.values(): for value in iterate_tree_cache_entry(value_d): yield value else: diff --git a/synapse/visibility.py b/synapse/visibility.py index bab41182b9..780927cda1 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,7 +16,6 @@ import logging import operator -from six import iteritems, itervalues from six.moves import map from twisted.internet import defer @@ -298,7 +297,7 @@ def filter_events_for_server( # membership states for the requesting server to determine # if the server is either in the room or has been invited # into the room. - for ev in itervalues(state): + for ev in state.values(): if ev.type != EventTypes.Member: continue try: @@ -332,7 +331,7 @@ def filter_events_for_server( ) visibility_ids = set() - for sids in itervalues(event_to_state_ids): + for sids in event_to_state_ids.values(): hist = sids.get((EventTypes.RoomHistoryVisibility, "")) if hist: visibility_ids.add(hist) @@ -345,7 +344,7 @@ def filter_events_for_server( event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") - for e in itervalues(event_map) + for e in event_map.values() ) if not check_history_visibility_only: @@ -394,8 +393,8 @@ def filter_events_for_server( # event_id_to_state_key = { event_id: key - for key_to_eid in itervalues(event_to_state_ids) - for key, event_id in iteritems(key_to_eid) + for key_to_eid in event_to_state_ids.values() + for key, event_id in key_to_eid.items() } def include(typ, state_key): @@ -409,20 +408,16 @@ def filter_events_for_server( return state_key[idx + 1 :] == server_name event_map = yield storage.main.get_events( - [ - e_id - for e_id, key in iteritems(event_id_to_state_key) - if include(key[0], key[1]) - ] + [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] ) event_to_state = { e_id: { key: event_map[inner_e_id] - for key, inner_e_id in iteritems(key_to_eid) + for key, inner_e_id in key_to_eid.items() if inner_e_id in event_map } - for e_id, key_to_eid in iteritems(event_to_state_ids) + for e_id, key_to_eid in event_to_state_ids.items() } to_return = [] -- cgit 1.5.1 From cc32fa7358641b96f5d3dbc14d0cd068e676e256 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Jun 2020 16:20:34 -0400 Subject: Ensure the body is a string before comparing push rules. (#7701) --- changelog.d/7701.bugfix | 1 + synapse/push/push_rule_evaluator.py | 4 ++-- tests/push/test_push_rule_evaluator.py | 39 ++++++++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 11 deletions(-) create mode 100644 changelog.d/7701.bugfix (limited to 'synapse/push') diff --git a/changelog.d/7701.bugfix b/changelog.d/7701.bugfix new file mode 100644 index 0000000000..e5b10f75fd --- /dev/null +++ b/changelog.d/7701.bugfix @@ -0,0 +1 @@ +Do not break push rule evaluation when receiving an event with a non-string body. This is a long-standing bug. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 11032491af..aeac257a6e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -131,7 +131,7 @@ class PushRuleEvaluatorForEvent(object): # XXX: optimisation: cache our pattern regexps if condition["key"] == "content.body": body = self._event.content.get("body", None) - if not body: + if not body or not isinstance(body, str): return False return _glob_matches(pattern, body, word_boundary=True) @@ -147,7 +147,7 @@ class PushRuleEvaluatorForEvent(object): return False body = self._event.content.get("body", None) - if not body: + if not body or not isinstance(body, str): return False # Similar to _glob_matches, but do not treat display_name as a glob. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 9ae6a87d7b..af35d23aea 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -21,7 +21,7 @@ from tests import unittest class PushRuleEvaluatorTestCase(unittest.TestCase): - def setUp(self): + def _get_evaluator(self, content): event = FrozenEvent( { "event_id": "$event_id", @@ -29,37 +29,58 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "sender": "@user:test", "state_key": "", "room_id": "@room:test", - "content": {"body": "foo bar baz"}, + "content": content, }, RoomVersions.V1, ) room_member_count = 0 sender_power_level = 0 power_levels = {} - self.evaluator = PushRuleEvaluatorForEvent( + return PushRuleEvaluatorForEvent( event, room_member_count, sender_power_level, power_levels ) def test_display_name(self): """Check for a matching display name in the body of the event.""" + evaluator = self._get_evaluator({"body": "foo bar baz"}) + condition = { "kind": "contains_display_name", } # Blank names are skipped. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "")) + self.assertFalse(evaluator.matches(condition, "@user:test", "")) # Check a display name that doesn't match. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found")) + self.assertFalse(evaluator.matches(condition, "@user:test", "not found")) # Check a display name which matches. - self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo")) + self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) # A display name that matches, but not a full word does not result in a match. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba")) + self.assertFalse(evaluator.matches(condition, "@user:test", "ba")) # A display name should not be interpreted as a regular expression. - self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]")) + self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]")) # A display name with spaces should work fine. - self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar")) + self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + + def test_no_body(self): + """Not having a body shouldn't break the evaluator.""" + evaluator = self._get_evaluator({}) + + condition = { + "kind": "contains_display_name", + } + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_invalid_body(self): + """A non-string body should not break the evaluator.""" + condition = { + "kind": "contains_display_name", + } + + for body in (1, True, {"foo": "bar"}): + evaluator = self._get_evaluator({"body": body}) + self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) -- cgit 1.5.1 From a3f11567d930b7da0db068c3b313f6f4abbf12a1 Mon Sep 17 00:00:00 2001 From: Dagfinn Ilmari Mannsåker Date: Tue, 16 Jun 2020 13:51:47 +0100 Subject: Replace all remaining six usage with native Python 3 equivalents (#7704) --- changelog.d/7704.misc | 1 + contrib/graph/graph3.py | 4 +--- scripts-dev/federation_client.py | 3 +-- scripts/synapse_port_db | 4 +--- setup.cfg | 2 +- synapse/_scripts/register_new_matrix_user.py | 2 -- synapse/api/errors.py | 7 +++---- synapse/api/filtering.py | 4 +--- synapse/api/urls.py | 3 +-- synapse/appservice/__init__.py | 4 +--- synapse/appservice/api.py | 3 +-- synapse/config/_base.py | 6 ++---- synapse/config/appservice.py | 13 ++++--------- synapse/config/tls.py | 4 +--- synapse/crypto/keyring.py | 6 ++---- synapse/events/utils.py | 4 +--- synapse/events/validator.py | 12 +++++------- synapse/federation/federation_base.py | 4 +--- synapse/federation/federation_server.py | 4 +--- synapse/federation/transport/client.py | 3 +-- synapse/groups/groups_server.py | 4 +--- synapse/handlers/cas_handler.py | 3 +-- synapse/handlers/federation.py | 9 ++++----- synapse/handlers/message.py | 4 +--- synapse/handlers/profile.py | 8 +++----- synapse/handlers/room.py | 4 +--- synapse/handlers/room_member.py | 7 +++---- synapse/http/client.py | 8 +++----- synapse/http/matrixfederationclient.py | 12 +++++------- synapse/http/server.py | 4 ++-- synapse/logging/formatter.py | 3 +-- synapse/push/mailer.py | 3 +-- synapse/push/push_rule_evaluator.py | 4 +--- synapse/python_dependencies.py | 1 - synapse/replication/http/_base.py | 6 ++---- synapse/rest/admin/users.py | 20 ++++++-------------- synapse/rest/client/v1/presence.py | 4 +--- synapse/rest/client/v1/room.py | 3 +-- synapse/rest/client/v2_alpha/account.py | 5 ++--- synapse/rest/client/v2_alpha/register.py | 11 +++-------- synapse/rest/client/v2_alpha/report_event.py | 10 ++++------ synapse/rest/consent/consent_resource.py | 5 ++--- synapse/rest/media/v1/_base.py | 3 +-- synapse/rest/media/v1/media_storage.py | 6 +----- synapse/rest/media/v1/preview_url_resource.py | 9 +++------ synapse/server_notices/consent_server_notices.py | 4 +--- synapse/storage/data_stores/main/event_federation.py | 3 +-- synapse/storage/data_stores/main/events.py | 10 +++------- .../storage/data_stores/main/events_bg_updates.py | 4 +--- .../data_stores/main/schema/delta/30/as_users.py | 2 -- synapse/storage/data_stores/main/search.py | 4 +--- synapse/storage/data_stores/main/stream.py | 2 -- synapse/storage/data_stores/main/tags.py | 2 -- synapse/storage/data_stores/state/store.py | 2 -- synapse/storage/database.py | 3 +-- synapse/storage/persist_events.py | 2 -- synapse/util/async_helpers.py | 2 -- synapse/util/caches/stream_change_cache.py | 4 +--- synapse/util/file_consumer.py | 2 +- synapse/util/frozenutils.py | 6 ++---- synapse/util/wheel_timer.py | 2 -- synapse/visibility.py | 2 -- synctl | 6 ++---- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v2_alpha/test_relations.py | 9 ++++----- tests/rest/media/v1/test_media_storage.py | 2 +- tests/server.py | 4 +--- tests/state/test_v2.py | 2 -- tests/test_server.py | 3 +-- tests/test_terms_auth.py | 9 ++++----- tests/util/test_file_consumer.py | 2 +- tests/util/test_linearizer.py | 2 -- tests/utils.py | 2 +- 73 files changed, 111 insertions(+), 237 deletions(-) create mode 100644 changelog.d/7704.misc (limited to 'synapse/push') diff --git a/changelog.d/7704.misc b/changelog.d/7704.misc new file mode 100644 index 0000000000..7838a613c8 --- /dev/null +++ b/changelog.d/7704.misc @@ -0,0 +1 @@ +Replace all remaining uses of `six` with native Python 3 equivalents. Contributed by @ilmari. diff --git a/contrib/graph/graph3.py b/contrib/graph/graph3.py index 7f9e5374a6..3154638520 100644 --- a/contrib/graph/graph3.py +++ b/contrib/graph/graph3.py @@ -24,8 +24,6 @@ import argparse from synapse.events import FrozenEvent from synapse.util.frozenutils import unfreeze -from six import string_types - def make_graph(file_name, room_id, file_prefix, limit): print("Reading lines") @@ -62,7 +60,7 @@ def make_graph(file_name, room_id, file_prefix, limit): for key, value in unfreeze(event.get_dict()["content"]).items(): if value is None: value = "" - elif isinstance(value, string_types): + elif isinstance(value, str): pass else: value = json.dumps(value) diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index 7c19e405d4..531010185d 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -21,8 +21,7 @@ import argparse import base64 import json import sys - -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse import nacl.signing import requests diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 9a0fbc61d8..a0d81c77c2 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -23,8 +23,6 @@ import sys import time import traceback -from six import string_types - import yaml from twisted.internet import defer, reactor @@ -635,7 +633,7 @@ class Porter(object): return bool(col) if isinstance(col, bytes): return bytearray(col) - elif isinstance(col, string_types) and "\0" in col: + elif isinstance(col, str) and "\0" in col: logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, diff --git a/setup.cfg b/setup.cfg index 12a7849081..f2bca272e1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY known_first_party = synapse known_tests=tests -known_compat = mock,six +known_compat = mock known_twisted=twisted,OpenSSL multi_line_output=3 include_trailing_comma=true diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index d528450c78..55cce2db22 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -23,8 +23,6 @@ import hmac import logging import sys -from six.moves import input - import requests as _requests import yaml diff --git a/synapse/api/errors.py b/synapse/api/errors.py index a07a54580d..5305038c21 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,10 +17,9 @@ """Contains exceptions and error codes.""" import logging +from http import HTTPStatus from typing import Dict, List -from six.moves import http_client - from canonicaljson import json from twisted.web import http @@ -173,7 +172,7 @@ class ConsentNotGivenError(SynapseError): consent_url (str): The URL where the user can give their consent """ super(ConsentNotGivenError, self).__init__( - code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN + code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN ) self._consent_uri = consent_uri @@ -193,7 +192,7 @@ class UserDeactivatedError(SynapseError): msg (str): The human-readable error message """ super(UserDeactivatedError, self).__init__( - code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED + code=HTTPStatus.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED ) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 8b64d0a285..f988f62a1e 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -17,8 +17,6 @@ # limitations under the License. from typing import List -from six import text_type - import jsonschema from canonicaljson import json from jsonschema import FormatChecker @@ -313,7 +311,7 @@ class Filter(object): content = event.get("content", {}) # check if there is a string url field in the content for filtering purposes - contains_url = isinstance(content.get("url"), text_type) + contains_url = isinstance(content.get("url"), str) labels = content.get(EventContentFields.LABELS, []) return self.check_fields(room_id, sender, ev_type, labels, contains_url) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index f34434bd67..bd03ebca5a 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -17,8 +17,7 @@ """Contains the URL paths to prefix various aspects of the server with. """ import hmac from hashlib import sha256 - -from six.moves.urllib.parse import urlencode +from urllib.parse import urlencode from synapse.config import ConfigError diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1b13e84425..0323256472 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -15,8 +15,6 @@ import logging import re -from six import string_types - from twisted.internet import defer from synapse.api.constants import EventTypes @@ -156,7 +154,7 @@ class ApplicationService(object): ) regex = regex_obj.get("regex") - if isinstance(regex, string_types): + if isinstance(regex, str): regex_obj["regex"] = re.compile(regex) # Pre-compile regex else: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 57174da021..da9a5e86d4 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from six.moves import urllib +import urllib from prometheus_client import Counter diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 30d1050a91..1391e5fc43 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -22,8 +22,6 @@ from collections import OrderedDict from textwrap import dedent from typing import Any, MutableMapping, Optional -from six import integer_types - import yaml @@ -117,7 +115,7 @@ class Config(object): @staticmethod def parse_size(value): - if isinstance(value, integer_types): + if isinstance(value, int): return value sizes = {"K": 1024, "M": 1024 * 1024} size = 1 @@ -129,7 +127,7 @@ class Config(object): @staticmethod def parse_duration(value): - if isinstance(value, integer_types): + if isinstance(value, int): return value second = 1000 minute = 60 * second diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index ca43e96bd1..8ed3e24258 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -14,9 +14,7 @@ import logging from typing import Dict - -from six import string_types -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse import yaml from netaddr import IPSet @@ -98,17 +96,14 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"] for field in required_string_fields: - if not isinstance(as_info.get(field), string_types): + if not isinstance(as_info.get(field), str): raise KeyError( "Required string field: '%s' (%s)" % (field, config_filename) ) # 'url' must either be a string or explicitly null, not missing # to avoid accidentally turning off push for ASes. - if ( - not isinstance(as_info.get("url"), string_types) - and as_info.get("url", "") is not None - ): + if not isinstance(as_info.get("url"), str) and as_info.get("url", "") is not None: raise KeyError( "Required string field or explicit null: 'url' (%s)" % (config_filename,) ) @@ -138,7 +133,7 @@ def _load_appservice(hostname, as_info, config_filename): ns, regex_obj, ) - if not isinstance(regex_obj.get("regex"), string_types): + if not isinstance(regex_obj.get("regex"), str): raise ValueError("Missing/bad type 'regex' key in %s", regex_obj) if not isinstance(regex_obj.get("exclusive"), bool): raise ValueError( diff --git a/synapse/config/tls.py b/synapse/config/tls.py index a65538562b..e368ea564d 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -20,8 +20,6 @@ from datetime import datetime from hashlib import sha256 from typing import List -import six - from unpaddedbase64 import encode_base64 from OpenSSL import SSL, crypto @@ -59,7 +57,7 @@ class TlsConfig(Config): logger.warning(ACME_SUPPORT_ENABLED_WARN) # hyperlink complains on py2 if this is not a Unicode - self.acme_url = six.text_type( + self.acme_url = str( acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory") ) self.acme_port = acme_config.get("port", 80) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index a9f4025bfe..dbfc3e8972 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -15,11 +15,9 @@ # limitations under the License. import logging +import urllib from collections import defaultdict -import six -from six.moves import urllib - import attr from signedjson.key import ( decode_verify_key_bytes, @@ -661,7 +659,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): for response in query_response["server_keys"]: # do this first, so that we can give useful errors thereafter server_name = response.get("server_name") - if not isinstance(server_name, six.string_types): + if not isinstance(server_name, str): raise KeyLookupError( "Malformed response from key notary server %s: invalid server_name" % (perspective_name,) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index dd340be9a7..f6b507977f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -16,8 +16,6 @@ import collections import re from typing import Any, Mapping, Union -from six import string_types - from frozendict import frozendict from twisted.internet import defer @@ -318,7 +316,7 @@ def serialize_event( if only_event_fields: if not isinstance(only_event_fields, list) or not all( - isinstance(f, string_types) for f in only_event_fields + isinstance(f, str) for f in only_event_fields ): raise TypeError("only_event_fields must be a list of strings") d = only_fields(d, only_event_fields) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index b001c64bb4..588d222f36 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import integer_types, string_types - from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions @@ -53,7 +51,7 @@ class EventValidator(object): event_strings = ["origin"] for s in event_strings: - if not isinstance(getattr(event, s), string_types): + if not isinstance(getattr(event, s), str): raise SynapseError(400, "'%s' not a string type" % (s,)) # Depending on the room version, ensure the data is spec compliant JSON. @@ -90,7 +88,7 @@ class EventValidator(object): max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if not isinstance(min_lifetime, integer_types): + if not isinstance(min_lifetime, int): raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -124,7 +122,7 @@ class EventValidator(object): ) if max_lifetime is not None: - if not isinstance(max_lifetime, integer_types): + if not isinstance(max_lifetime, int): raise SynapseError( code=400, msg="'max_lifetime' must be an integer", @@ -183,7 +181,7 @@ class EventValidator(object): strings.append("state_key") for s in strings: - if not isinstance(getattr(event, s), string_types): + if not isinstance(getattr(event, s), str): raise SynapseError(400, "Not '%s' a string type" % (s,)) RoomID.from_string(event.room_id) @@ -223,7 +221,7 @@ class EventValidator(object): for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) - if not isinstance(d[s], string_types): + if not isinstance(d[s], str): raise SynapseError(400, "'%s' not a string type" % (s,)) def _ensure_state_event(self, event): diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index b2ab5bd6a4..420df2385f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Iterable, List -import six - from twisted.internet import defer from twisted.internet.defer import Deferred, DeferredList from twisted.python.failure import Failure @@ -294,7 +292,7 @@ def event_from_pdu_json( assert_params_in_dict(pdu_json, ("type", "depth")) depth = pdu_json["depth"] - if not isinstance(depth, six.integer_types): + if not isinstance(depth, int): raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6920c23723..afe0a8238b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -17,8 +17,6 @@ import logging from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union -import six - from canonicaljson import json from prometheus_client import Counter @@ -751,7 +749,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: def _acl_entry_matches(server_name: str, acl_entry: str) -> Match: - if not isinstance(acl_entry, six.string_types): + if not isinstance(acl_entry, str): logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 060bf07197..9f99311419 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,10 +15,9 @@ # limitations under the License. import logging +import urllib from typing import Any, Dict, Optional -from six.moves import urllib - from twisted.internet import defer from synapse.api.constants import Membership diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8a9de913b3..8db8ab1b7b 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -17,8 +17,6 @@ import logging -from six import string_types - from synapse.api.errors import Codes, SynapseError from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute @@ -513,7 +511,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): for keyname in ("name", "avatar_url", "short_description", "long_description"): if keyname in content: value = content[keyname] - if not isinstance(value, string_types): + if not isinstance(value, str): raise SynapseError(400, "%r value is not a string" % (keyname,)) profile[keyname] = value diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 64aaa1335c..76f213723a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,11 +14,10 @@ # limitations under the License. import logging +import urllib import xml.etree.ElementTree as ET from typing import Dict, Optional, Tuple -from six.moves import urllib - from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d6038d9995..873f6bc39f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,10 +19,9 @@ import itertools import logging +from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Sequence, Tuple -from six.moves import http_client, zip - import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json @@ -1194,7 +1193,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.prev_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: logger.warning( @@ -1202,7 +1201,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.auth_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -1545,7 +1544,7 @@ class FederationHandler(BaseHandler): # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 354da9a3b5..200127d291 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -17,8 +17,6 @@ import logging from typing import Optional, Tuple -from six import string_types - from canonicaljson import encode_canonical_json, json from twisted.internet import defer @@ -715,7 +713,7 @@ class EventCreationHandler(object): spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: - if not isinstance(spam_error, string_types): + if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 302efc1b9a..4b1e3073a8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -15,8 +15,6 @@ import logging -from six import raise_from - from twisted.internet import defer from synapse.api.errors import ( @@ -84,7 +82,7 @@ class BaseProfileHandler(BaseHandler): ) return result except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -135,7 +133,7 @@ class BaseProfileHandler(BaseHandler): ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -212,7 +210,7 @@ class BaseProfileHandler(BaseHandler): ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f7401373ca..950a84acd0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -24,8 +24,6 @@ import string from collections import OrderedDict from typing import Tuple -from six import string_types - from synapse.api.constants import ( EventTypes, JoinRules, @@ -595,7 +593,7 @@ class RoomCreationHandler(BaseHandler): "room_version", self.config.default_room_version.identifier ) - if not isinstance(room_version_id, string_types): + if not isinstance(room_version_id, str): raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0f7af982f0..27c479da9e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,10 +17,9 @@ import abc import logging +from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Tuple -from six.moves import http_client - from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError @@ -361,7 +360,7 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -444,7 +443,7 @@ class RoomMemberHandler(object): is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( - http_client.FORBIDDEN, + HTTPStatus.FORBIDDEN, "You cannot reject this invite", errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM, ) diff --git a/synapse/http/client.py b/synapse/http/client.py index 3cef747a4d..8743e9839d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -15,11 +15,9 @@ # limitations under the License. import logging +import urllib from io import BytesIO -from six import raise_from, text_type -from six.moves import urllib - import treq from canonicaljson import encode_canonical_json, json from netaddr import IPAddress @@ -577,7 +575,7 @@ class SimpleHttpClient(object): # This can happen e.g. because the body is too large. raise except Exception as e: - raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e) + raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e return ( length, @@ -638,7 +636,7 @@ def encode_urlencode_args(args): def encode_urlencode_arg(arg): - if isinstance(arg, text_type): + if isinstance(arg, str): return arg.encode("utf-8") elif isinstance(arg, list): return [encode_urlencode_arg(i) for i in arg] diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2d47b9ea00..7b33b9f10a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -17,11 +17,9 @@ import cgi import logging import random import sys +import urllib from io import BytesIO -from six import raise_from, string_types -from six.moves import urllib - import attr import treq from canonicaljson import encode_canonical_json @@ -432,10 +430,10 @@ class MatrixFederationHttpClient(object): except TimeoutError as e: raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: - raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e) + raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e except Exception as e: logger.info("Failed to send request: %s", e) - raise_from(RequestSendFailed(e, can_retry=True), e) + raise RequestSendFailed(e, can_retry=True) from e incoming_responses_counter.labels( request.method, response.code @@ -487,7 +485,7 @@ class MatrixFederationHttpClient(object): # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException if response.code == 429: - raise_from(RequestSendFailed(e, can_retry=True), e) + raise RequestSendFailed(e, can_retry=True) from e else: raise e @@ -998,7 +996,7 @@ def encode_query_args(args): encoded_args = {} for k, vs in args.items(): - if isinstance(vs, string_types): + if isinstance(vs, str): vs = [vs] encoded_args[k] = [v.encode("UTF-8") for v in vs] diff --git a/synapse/http/server.py b/synapse/http/server.py index 2487a72171..6aa1dc1f92 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,10 +16,10 @@ import collections import html -import http.client import logging import types import urllib +from http import HTTPStatus from io import BytesIO from typing import Awaitable, Callable, TypeVar, Union @@ -188,7 +188,7 @@ def return_html_error( exc_info=(f.type, f.value, f.getTracebackObject()), ) else: - code = http.HTTPStatus.INTERNAL_SERVER_ERROR + code = HTTPStatus.INTERNAL_SERVER_ERROR msg = "Internal server error" logger.error( diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index fbf570c756..d736ad5b9b 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -16,8 +16,7 @@ import logging import traceback - -from six import StringIO +from io import StringIO class LogFormatter(logging.Formatter): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index d57a66a697..dda560b2c2 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -17,12 +17,11 @@ import email.mime.multipart import email.utils import logging import time +import urllib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar -from six.moves import urllib - import bleach import jinja2 diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index aeac257a6e..8e0d3a416d 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -18,8 +18,6 @@ import logging import re from typing import Pattern -from six import string_types - from synapse.events import EventBase from synapse.types import UserID from synapse.util.caches import register_cache @@ -244,7 +242,7 @@ def _flatten_dict(d, prefix=[], result=None): if result is None: result = {} for key, value in d.items(): - if isinstance(value, string_types): + if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() elif hasattr(value, "items"): _flatten_dict(value, prefix=(prefix + [key]), result=result) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8ec1a619a2..d655aba35c 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -66,7 +66,6 @@ REQUIREMENTS = [ "pymacaroons>=0.13.0", "msgpack>=0.5.2", "phonenumbers>=8.2.0", - "six>=1.10", "prometheus_client>=0.0.18,<0.8.0", # we use attr.validators.deep_iterable, which arrived in 19.1.0 "attrs>=19.1.0", diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 793cef6c26..9caf1e80c1 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,12 +16,10 @@ import abc import logging import re +import urllib from inspect import signature from typing import Dict, List, Tuple -from six import raise_from -from six.moves import urllib - from twisted.internet import defer from synapse.api.errors import ( @@ -220,7 +218,7 @@ class ReplicationEndpoint(object): # importantly, not stack traces everywhere) raise e.to_synapse_error() except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to talk to master"), e) + raise SynapseError(502, "Failed to talk to master") from e return result diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index fefc8f71fa..e4330c39d6 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -16,9 +16,7 @@ import hashlib import hmac import logging import re - -from six import text_type -from six.moves import http_client +from http import HTTPStatus from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -215,10 +213,7 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: - if ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): + if not isinstance(body["password"], str) or len(body["password"]) > 512: raise SynapseError(400, "Invalid password") else: new_password = body["password"] @@ -252,7 +247,7 @@ class UserRestServletV2(RestServlet): password = body.get("password") password_hash = None if password is not None: - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -370,10 +365,7 @@ class UserRegisterServlet(RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON ) else: - if ( - not isinstance(body["username"], text_type) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -386,7 +378,7 @@ class UserRegisterServlet(RestServlet): ) else: password = body["password"] - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_bytes = password.encode("utf-8") @@ -477,7 +469,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 7cf007d35e..970fdd5834 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,8 +17,6 @@ """ import logging -from six import string_types - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -73,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet): if "status_msg" in content: state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], string_types): + if not isinstance(state["status_msg"], str): raise SynapseError(400, "status_msg must be a string.") if content: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 105e0cf4d2..46811abbfa 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -18,8 +18,7 @@ import logging import re from typing import List, Optional - -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse from canonicaljson import json diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1dc4a3247f..923bcb9f85 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,8 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from six.moves import http_client +from http import HTTPStatus from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError @@ -321,7 +320,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b9ffe86b2a..141a3f5fac 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -18,8 +18,6 @@ import hmac import logging from typing import List, Union -from six import string_types - import synapse import synapse.api.auth import synapse.types @@ -413,7 +411,7 @@ class RegisterRestServlet(RestServlet): # in sessions. Pull out the username/password provided to us. if "password" in body: password = body.pop("password") - if not isinstance(password, string_types) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(password) @@ -425,10 +423,7 @@ class RegisterRestServlet(RestServlet): desired_username = None if "username" in body: - if ( - not isinstance(body["username"], string_types) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") desired_username = body["username"] @@ -453,7 +448,7 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if isinstance(desired_username, string_types): + if isinstance(desired_username, str): result = await self._do_appservice_registration( desired_username, access_token, body ) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index f067b5edac..e15927c4ea 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -14,9 +14,7 @@ # limitations under the License. import logging - -from six import string_types -from six.moves import http_client +from http import HTTPStatus from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( @@ -47,15 +45,15 @@ class ReportEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ("reason", "score")) - if not isinstance(body["reason"], string_types): + if not isinstance(body["reason"], str): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'reason' must be a string", Codes.BAD_JSON, ) if not isinstance(body["score"], int): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", Codes.BAD_JSON, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 1ddf9997ff..049c16b236 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -16,10 +16,9 @@ import hmac import logging from hashlib import sha256 +from http import HTTPStatus from os import path -from six.moves import http_client - import jinja2 from jinja2 import TemplateNotFound @@ -223,4 +222,4 @@ class ConsentResource(DirectServeResource): ) if not compare_digest(want_mac, userhmac): - raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 3689777266..595849f9d5 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -16,8 +16,7 @@ import logging import os - -from six.moves import urllib +import urllib from twisted.internet import defer from twisted.protocols.basic import FileSender diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 683a79c966..79cb0dddbe 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -17,9 +17,6 @@ import contextlib import logging import os import shutil -import sys - -import six from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -117,12 +114,11 @@ class MediaStorage(object): with open(fname, "wb") as f: yield f, fname, finish except Exception: - t, v, tb = sys.exc_info() try: os.remove(fname) except Exception: pass - six.reraise(t, v, tb) + raise if not finished_called: raise Exception("Finished callback not called") diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index f206605727..f67e0fb3ec 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -24,10 +24,7 @@ import shutil import sys import traceback from typing import Dict, Optional - -import six -from six import string_types -from six.moves import urllib_parse as urlparse +from urllib import parse as urlparse from canonicaljson import json @@ -188,7 +185,7 @@ class PreviewUrlResource(DirectServeResource): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. og = cache_result["og"] - if isinstance(og, six.text_type): + if isinstance(og, str): og = og.encode("utf8") return og @@ -631,7 +628,7 @@ def _iterate_over_text(tree, *tags_to_ignore): if el is None: return - if isinstance(el, string_types): + if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index e7e8b8e688..3bfc8d7278 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from six import string_types - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -118,7 +116,7 @@ def copy_with_str_subst(x, substitutions): Returns: copy of x """ - if isinstance(x, string_types): + if isinstance(x, str): return x % substitutions if isinstance(x, dict): return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()} diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 24ce8c4330..a6bb3221ff 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,10 +14,9 @@ # limitations under the License. import itertools import logging +from queue import Empty, PriorityQueue from typing import Dict, List, Optional, Set, Tuple -from six.moves.queue import Empty, PriorityQueue - from twisted.internet import defer from synapse.api.errors import StoreError diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 8a13101f1d..cfd24d2f06 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -21,9 +21,6 @@ from collections import OrderedDict, namedtuple from functools import wraps from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple -from six import integer_types, text_type -from six.moves import range - import attr from canonicaljson import json from prometheus_client import Counter @@ -893,8 +890,7 @@ class PersistEventsStore: "received_ts": self._clock.time_msec(), "sender": event.sender, "contains_url": ( - "url" in event.content - and isinstance(event.content["url"], text_type) + "url" in event.content and isinstance(event.content["url"], str) ), } for event, _ in events_and_contexts @@ -1345,10 +1341,10 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), integer_types) + and not isinstance(event.content.get("min_lifetime"), int) ) or ( "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), integer_types) + and not isinstance(event.content.get("max_lifetime"), int) ): # Ignore the event if one of the value isn't an integer. return diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index f54c8b1ee0..62d28f44dc 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -15,8 +15,6 @@ import logging -from six import text_type - from canonicaljson import json from twisted.internet import defer @@ -133,7 +131,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], text_type) + contains_url &= isinstance(content["url"], str) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py index 9b95411fb6..b42c02710a 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py +++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py @@ -13,8 +13,6 @@ # limitations under the License. import logging -from six.moves import range - from synapse.config.appservice import load_appservices logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 13f49d8060..a8381dc577 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -17,8 +17,6 @@ import logging import re from collections import namedtuple -from six import string_types - from canonicaljson import json from twisted.internet import defer @@ -180,7 +178,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): # skip over it. continue - if not isinstance(value, string_types): + if not isinstance(value, str): # If the event body, name or topic isn't a string # then skip over it continue diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index e89f0bffb5..379d758b5d 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -40,8 +40,6 @@ import abc import logging from collections import namedtuple -from six.moves import range - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 4219018302..f8c776be3f 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -16,8 +16,6 @@ import logging -from six.moves import range - from canonicaljson import json from twisted.internet import defer diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index b720212e55..5db9f20135 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Dict, Iterable, List, Set, Tuple -from six.moves import range - from twisted.internet import defer from synapse.api.constants import EventTypes diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 645a70934c..3be20c866a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -16,6 +16,7 @@ # limitations under the License. import logging import time +from sys import intern from time import monotonic as monotonic_time from typing import ( Any, @@ -29,8 +30,6 @@ from typing import ( TypeVar, ) -from six.moves import intern, range - from prometheus_client import Histogram from twisted.enterprise import adbapi diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 92dfd709bc..ec894a91cb 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -20,8 +20,6 @@ import logging from collections import deque, namedtuple from typing import Iterable, List, Optional, Set, Tuple -from six.moves import range - from prometheus_client import Counter, Histogram from twisted.internet import defer diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f7af2bca7f..df42486351 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -19,8 +19,6 @@ import logging from contextlib import contextmanager from typing import Dict, Sequence, Set, Union -from six.moves import range - import attr from twisted.internet import defer diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 2a161bf244..c541bf4579 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -17,8 +17,6 @@ import logging import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union -from six import integer_types - from sortedcontainers import SortedDict from synapse.types import Collection @@ -88,7 +86,7 @@ class StreamChangeCache: def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ - assert type(stream_pos) in integer_types + assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: self.metrics.inc_misses() diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py index 8b17d1c8b8..6a3f6177b1 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import queue +import queue from twisted.internet import threads diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 9815bb8667..eab78dd256 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import binary_type, text_type - from canonicaljson import json from frozendict import frozendict @@ -26,7 +24,7 @@ def freeze(o): if isinstance(o, frozendict): return o - if isinstance(o, (binary_type, text_type)): + if isinstance(o, (bytes, str)): return o try: @@ -41,7 +39,7 @@ def unfreeze(o): if isinstance(o, (dict, frozendict)): return dict({k: unfreeze(v) for k, v in o.items()}) - if isinstance(o, (binary_type, text_type)): + if isinstance(o, (bytes, str)): return o try: diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index 9bf6a44f75..023beb5ede 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import range - class _Entry(object): __slots__ = ["end_key", "queue"] diff --git a/synapse/visibility.py b/synapse/visibility.py index 780927cda1..3dfd4af26c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,8 +16,6 @@ import logging import operator -from six.moves import map - from twisted.internet import defer from synapse.api.constants import EventTypes, Membership diff --git a/synctl b/synctl index 960fd357ee..ca398b84bd 100755 --- a/synctl +++ b/synctl @@ -26,8 +26,6 @@ import subprocess import sys import time -from six import iteritems - import yaml from synapse.config import find_config_files @@ -251,7 +249,7 @@ def main(): os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) cache_factors = config.get("synctl_cache_factors", {}) - for cache_name, factor in iteritems(cache_factors): + for cache_name, factor in cache_factors.items(): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) worker_configfiles = [] @@ -362,7 +360,7 @@ def main(): if worker.cache_factor: os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor) - for cache_name, factor in iteritems(worker.cache_factors): + for cache_name, factor in worker.cache_factors.items(): os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) if not start_worker(worker.app, configfile, worker.configfile): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 4886bbb401..5ccda8b2bd 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,9 +19,9 @@ """Tests REST events for /rooms paths.""" import json +from urllib import parse as urlparse from mock import Mock -from six.moves.urllib import parse as urlparse from twisted.internet import defer diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index c7e5859970..fd641a7c2f 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -15,8 +15,7 @@ import itertools import json - -import six +import urllib from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): # Make sure next_batch has something in it that looks like it could be a # valid token. self.assertIsInstance( - channel.json_body.get("next_batch"), six.string_types, channel.json_body + channel.json_body.get("next_batch"), str, channel.json_body ) def test_repeated_paginate_relations(self): @@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) for _ in range(20): from_token = "" if prev_token: @@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): query = "" if key: - query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) + query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) original_id = parent_id if parent_id else self.parent_id diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1ca648ef2b..aefe648bdb 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -20,9 +20,9 @@ import tempfile from binascii import unhexlify from io import BytesIO from typing import Optional +from urllib import parse from mock import Mock -from six.moves.urllib import parse import attr import PIL.Image as Image diff --git a/tests/server.py b/tests/server.py index 1644710aa0..a5e57c52fa 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,8 +2,6 @@ import json import logging from io import BytesIO -from six import text_type - import attr from zope.interface import implementer @@ -174,7 +172,7 @@ def make_request( if not path.startswith(b"/"): path = b"/" + path - if isinstance(content, text_type): + if isinstance(content, str): content = content.encode("utf8") site = FakeSite() diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index a44960203e..cdc347bc53 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -15,8 +15,6 @@ import itertools -from six.moves import zip - import attr from synapse.api.constants import EventTypes, JoinRules, Membership diff --git a/tests/test_server.py b/tests/test_server.py index adae3c6e08..3f6f468e5b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,8 +14,7 @@ import logging import re - -from six import StringIO +from io import StringIO from twisted.internet.defer import Deferred from twisted.python.failure import Failure diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 5c2817cf28..b89798336c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -14,7 +14,6 @@ import json -import six from mock import Mock from twisted.test.proto_helpers import MemoryReactorClock @@ -60,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"401", channel.result) self.assertTrue(channel.json_body is not None) - self.assertIsInstance(channel.json_body["session"], six.text_type) + self.assertIsInstance(channel.json_body["session"], str) self.assertIsInstance(channel.json_body["flows"], list) for flow in channel.json_body["flows"]: @@ -125,6 +124,6 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) self.assertTrue(channel.json_body is not None) - self.assertIsInstance(channel.json_body["user_id"], six.text_type) - self.assertIsInstance(channel.json_body["access_token"], six.text_type) - self.assertIsInstance(channel.json_body["device_id"], six.text_type) + self.assertIsInstance(channel.json_body["user_id"], str) + self.assertIsInstance(channel.json_body["access_token"], str) + self.assertIsInstance(channel.json_body["device_id"], str) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index e90e08d1c0..8d6627ec33 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -15,9 +15,9 @@ import threading +from io import StringIO from mock import NonCallableMock -from six import StringIO from twisted.internet import defer, reactor diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index ca3858b184..0e52811948 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six.moves import range - from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError diff --git a/tests/utils.py b/tests/utils.py index 7ba8a31ff3..4d17355a5c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,9 +21,9 @@ import time import uuid import warnings from inspect import getcallargs +from urllib import parse as urlparse from mock import Mock, patch -from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor -- cgit 1.5.1 From f6f7511a4c0548b17bd1cdabebd0ffad9ea73bc7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Jun 2020 17:10:28 +0100 Subject: Refactor getting replication updates from database. (#7636) The aim here is to make it easier to reason about when streams are limited and when they're not, by moving the logic into the database functions themselves. This should mean we can kill of `db_query_to_update_function` function. --- changelog.d/7636.misc | 1 + synapse/handlers/presence.py | 29 +++++++- synapse/handlers/typing.py | 40 ++++++++--- synapse/push/pusherpool.py | 4 +- synapse/replication/tcp/streams/_base.py | 29 +++----- synapse/storage/data_stores/main/events_worker.py | 41 ++++++++++-- synapse/storage/data_stores/main/presence.py | 41 ++++++++++-- synapse/storage/data_stores/main/push_rule.py | 56 ++++++++++++---- synapse/storage/data_stores/main/receipts.py | 82 +++++++++++++++++++---- 9 files changed, 251 insertions(+), 72 deletions(-) create mode 100644 changelog.d/7636.misc (limited to 'synapse/push') diff --git a/changelog.d/7636.misc b/changelog.d/7636.misc new file mode 100644 index 0000000000..f93149502e --- /dev/null +++ b/changelog.d/7636.misc @@ -0,0 +1 @@ +Refactor getting replication updates from database. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2e8914be14..d2f25ae12a 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,7 +25,7 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set +from typing import Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager @@ -773,7 +773,9 @@ class PresenceHandler(BasePresenceHandler): return False - async def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -785,10 +787,31 @@ class PresenceHandler(BasePresenceHandler): - last_user_sync_ts(int) - status_msg(int) - currently_active(int) + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id, limit) + rows = await self.store.get_all_presence_updates( + instance_name, last_id, current_id, limit + ) return rows def notify_new_event(self): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c7bc14c623..4330abb9f7 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import List +from typing import List, Tuple from twisted.internet import defer @@ -259,14 +259,31 @@ class TypingHandler(object): ) async def get_all_typing_updates( - self, last_id: int, current_id: int, limit: int - ) -> List[dict]: - """Get up to `limit` typing updates between the given tokens, earliest - updates first. + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for typing replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return [] + return [], current_id, False changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id @@ -280,9 +297,16 @@ class TypingHandler(object): serial = self._room_serials[room_id] if last_id < serial <= current_id: typing = self._room_typing[room_id] - rows.append((serial, room_id, list(typing))) + rows.append((serial, [room_id, list(typing)])) rows.sort() - return rows[:limit] + + limited = False + if len(rows) > limit: + rows = rows[:limit] + current_id = rows[-1][0] + limited = True + + return rows, current_id, limited def get_current_token(self): return self._latest_room_serial diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 88d203aa44..f6a5458681 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -215,11 +215,9 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - updated_receipts = yield self.store.get_all_updated_receipts( + users_affected = yield self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) - # This returns a tuple, user_id is at index 3 - users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4acefc8a96..f196eff072 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -264,7 +264,7 @@ class BackfillStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_backfill_token), - db_query_to_update_function(store.get_all_new_backfill_event_rows), + store.get_all_new_backfill_event_rows, ) @@ -291,9 +291,7 @@ class PresenceStream(Stream): if hs.config.worker_app is None: # on the master, query the presence handler presence_handler = hs.get_presence_handler() - update_function = db_query_to_update_function( - presence_handler.get_all_presence_updates - ) + update_function = presence_handler.get_all_presence_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -318,9 +316,7 @@ class TypingStream(Stream): if hs.config.worker_app is None: # on the master, query the typing handler - update_function = db_query_to_update_function( - typing_handler.get_all_typing_updates - ) + update_function = typing_handler.get_all_typing_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -352,7 +348,7 @@ class ReceiptsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), - db_query_to_update_function(store.get_all_updated_receipts), + store.get_all_updated_receipts, ) @@ -367,26 +363,17 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() + super(PushRulesStream, self).__init__( - hs.get_instance_name(), self._current_token, self._update_function + hs.get_instance_name(), + self._current_token, + self.store.get_all_push_rule_updates, ) def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function( - self, instance_name: str, from_token: Token, to_token: Token, limit: int - ): - rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - - limited = False - if len(rows) == limit: - to_token = rows[-1][0] - limited = True - - return [(row[0], (row[2],)) for row in rows], to_token, limited - class PushersStream(Stream): """A user has added/changed/removed a pusher diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 213d69100a..a48c7a96ca 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1077,9 +1077,32 @@ class EventsWorkerStore(SQLBaseStore): "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + async def get_all_new_backfill_event_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for backfill replication stream, including all new + backfilled events and events that have gone from being outliers to not. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_new_backfill_event_rows(txn): sql = ( @@ -1094,10 +1117,12 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() + new_event_updates = [(row[0], row[1:]) for row in txn] + limited = False if len(new_event_updates) == limit: upper_bound = new_event_updates[-1][0] + limited = True else: upper_bound = current_id @@ -1114,11 +1139,15 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) + new_event_updates.extend((row[0], row[1:]) for row in txn) - return new_event_updates + if len(new_event_updates) >= limit: + upper_bound = new_event_updates[-1][0] + limited = True - return self.db.runInteraction( + return new_event_updates, upper_bound, limited + + return await self.db.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index dab31e0c2d..7574612619 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause @@ -73,9 +75,32 @@ class PresenceStore(SQLBaseStore): ) txn.execute(sql + clause, [stream_id] + list(args)) - def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for presence replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_presence_updates_txn(txn): sql = """ @@ -89,9 +114,17 @@ class PresenceStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + updates = [(row[0], row[1:]) for row in txn] + + upper_bound = current_id + limited = False + if len(updates) >= limit: + upper_bound = updates[-1][0] + limited = True + + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index ef8f40959f..f6e78ca590 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -16,7 +16,7 @@ import abc import logging -from typing import Union +from typing import List, Tuple, Union from canonicaljson import json @@ -348,23 +348,53 @@ class PushRulesWorkerStore( results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" + async def get_all_push_rule_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for push_rules replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) + sql = """ + SELECT stream_id, user_id + FROM push_rules_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] + + limited = False + upper_bound = current_id + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index d4a7163049..8f5505bd67 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -267,26 +268,79 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - def get_all_updated_receipts(self, last_id, current_id, limit=None): + def get_users_sent_receipts_between(self, last_id: int, current_id: int): + """Get all users who sent receipts between `last_id` exclusive and + `current_id` inclusive. + + Returns: + Deferred[List[str]] + """ + if last_id == current_id: return defer.succeed([]) - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) + def _get_users_sent_receipts_between_txn(txn): + sql = """ + SELECT DISTINCT user_id FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (last_id, current_id)) - return [r[0:5] + (json.loads(r[5]),) for r in txn] + return [r[0] for r in txn] return self.db.runInteraction( + "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn + ) + + async def get_all_updated_receipts( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for receipts replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_receipts_txn(txn): + sql = """ + SELECT stream_id, room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn] + + limited = False + upper_bound = current_id + + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited + + return await self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) -- cgit 1.5.1 From 5a5cf6460ec4b4bb3a07813c36717b5a8d4a697c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 17 Jun 2020 15:10:09 +0100 Subject: Fix unread counts in sync * Always return an unread_count in get_unread_event_push_actions_by_room_for_user * Don't always expect unread_count to be there so we don't take out sync entirely if something goes wrong --- changelog.d/7716.feature | 1 + synapse/push/push_tools.py | 2 +- synapse/storage/data_stores/main/event_push_actions.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7716.feature (limited to 'synapse/push') diff --git a/changelog.d/7716.feature b/changelog.d/7716.feature new file mode 100644 index 0000000000..ecc3ffd8d5 --- /dev/null +++ b/changelog.d/7716.feature @@ -0,0 +1 @@ +Add a per-room counter for unread messages in responses to `/sync` requests. Implements [MSC2625](https://github.com/matrix-org/matrix-doc/pull/2625). diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9f264ca4a4..4ea683fee0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -42,7 +42,7 @@ def get_badge_count(store, user_id): # We're populating this badge using the unread_count (instead of the # notify_count) as this badge is the number of missed messages, not the # number of missed notifications. - badge += 1 if notifs["unread_count"] else 0 + badge += 1 if notifs.get("unread_count") else 0 return badge diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index ba1b33a0a9..815d52ab4c 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -123,7 +123,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, last_read_event_id)) results = txn.fetchall() if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} + return {"notify_count": 0, "highlight_count": 0, "unread_count": 0} stream_ordering = results[0][0] -- cgit 1.5.1 From 7d2824395faf66347d4534635b408e6ea21d110d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 18 Jun 2020 10:47:06 +0100 Subject: add a comment --- synapse/push/httppusher.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/push') diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index eaaa7afc91..ed60dbc1bf 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -129,6 +129,8 @@ class HttpPusher(object): @defer.inlineCallbacks def _update_badge(self): + # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems + # to be largely redundant. perhaps we can remove it. badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) yield self._send_badge(badge) -- cgit 1.5.1 From 74d3e177f0443f27e670f0b99299d715c58fd238 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 1 Jul 2020 11:08:25 +0100 Subject: Back out MSC2625 implementation (#7761) --- changelog.d/7673.feature | 1 - changelog.d/7716.feature | 1 - changelog.d/7761.feature | 1 + synapse/handlers/sync.py | 3 - synapse/push/bulk_push_rule_evaluator.py | 7 +- synapse/push/push_tools.py | 5 +- synapse/rest/client/v1/push_rule.py | 4 +- .../storage/data_stores/main/event_push_actions.py | 133 +++++---------------- .../delta/58/07push_summary_unread_count.sql | 23 ---- tests/replication/slave/storage/test_events.py | 19 +-- tests/storage/test_event_push_actions.py | 45 +++---- 11 files changed, 53 insertions(+), 189 deletions(-) delete mode 100644 changelog.d/7673.feature delete mode 100644 changelog.d/7716.feature create mode 100644 changelog.d/7761.feature delete mode 100644 synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql (limited to 'synapse/push') diff --git a/changelog.d/7673.feature b/changelog.d/7673.feature deleted file mode 100644 index ecc3ffd8d5..0000000000 --- a/changelog.d/7673.feature +++ /dev/null @@ -1 +0,0 @@ -Add a per-room counter for unread messages in responses to `/sync` requests. Implements [MSC2625](https://github.com/matrix-org/matrix-doc/pull/2625). diff --git a/changelog.d/7716.feature b/changelog.d/7716.feature deleted file mode 100644 index ecc3ffd8d5..0000000000 --- a/changelog.d/7716.feature +++ /dev/null @@ -1 +0,0 @@ -Add a per-room counter for unread messages in responses to `/sync` requests. Implements [MSC2625](https://github.com/matrix-org/matrix-doc/pull/2625). diff --git a/changelog.d/7761.feature b/changelog.d/7761.feature new file mode 100644 index 0000000000..c97864677a --- /dev/null +++ b/changelog.d/7761.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0b82aa72a6..4c7524493e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1893,9 +1893,6 @@ class SyncHandler(object): if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] unread_notifications["highlight_count"] = notifs["highlight_count"] - unread_notifications["org.matrix.msc2625.unread_count"] = notifs[ - "unread_count" - ] sync_result_builder.joined.append(room_sync) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 5b00602a56..43ffe6faf0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -189,11 +189,8 @@ class BulkPushRuleEvaluator(object): ) if matches: actions = [x for x in rule["actions"] if x != "dont_notify"] - if ( - "notify" in actions - or "org.matrix.msc2625.mark_unread" in actions - ): - # Push rules say we should act on this event. + if actions and "notify" in actions: + # Push rules say we should notify the user of this event actions_by_user[uid] = actions break diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 4ea683fee0..5dae4648c0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -39,10 +39,7 @@ def get_badge_count(store, user_id): ) # return one badge count per conversation, as count per # message is so noisy as to be almost useless - # We're populating this badge using the unread_count (instead of the - # notify_count) as this badge is the number of missed messages, not the - # number of missed notifications. - badge += 1 if notifs.get("unread_count") else 0 + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index f563b3dc35..9fd4908136 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2020 The Matrix.org Foundation C.I.C. +# Copyright 2014-2016 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -267,7 +267,7 @@ def _check_actions(actions): raise InvalidRuleException("No actions found") for a in actions: - if a in ["notify", "dont_notify", "coalesce", "org.matrix.msc2625.mark_unread"]: + if a in ["notify", "dont_notify", "coalesce"]: pass elif isinstance(a, dict) and "set_tweak" in a: pass diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 815d52ab4c..bc9f4f08ea 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2015-2020 The Matrix.org Foundation C.I.C. +# Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +15,7 @@ # limitations under the License. import logging -from typing import Dict, Tuple -import attr from canonicaljson import json from twisted.internet import defer @@ -37,16 +36,6 @@ DEFAULT_HIGHLIGHT_ACTION = [ ] -@attr.s -class EventPushSummary: - """Summary of pending event push actions for a given user in a given room.""" - - unread_count = attr.ib(type=int) - stream_ordering = attr.ib(type=int) - old_user_id = attr.ib(type=str) - notif_count = attr.ib(type=int) - - def _serialize_action(actions, is_highlight): """Custom serializer for actions. This allows us to "compress" common actions. @@ -123,7 +112,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, last_read_event_id)) results = txn.fetchall() if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0, "unread_count": 0} + return {"notify_count": 0, "highlight_count": 0} stream_ordering = results[0][0] @@ -133,42 +122,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): - # First get number of actions, grouped on whether the action notifies. + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 sql = ( - "SELECT count(*), notif" + "SELECT count(*)" " FROM event_push_actions ea" " WHERE" " user_id = ?" " AND room_id = ?" " AND stream_ordering > ?" - " GROUP BY notif" ) - txn.execute(sql, (user_id, room_id, stream_ordering)) - rows = txn.fetchall() - # We should get a maximum number of two rows: one for notif = 0, which is the - # number of actions that contribute to the unread_count but not to the - # notify_count, and one for notif = 1, which is the number of actions that - # contribute to both counters. If one or both rows don't appear, then the - # value for the matching counter should be 0. - unread_count = 0 - notify_count = 0 - for row in rows: - # We always increment unread_count because actions that notify also - # contribute to it. - unread_count += row[0] - if row[1] == 1: - notify_count = row[0] - elif row[1] != 0: - logger.warning( - "Unexpected value %d for column 'notif' in table" - " 'event_push_actions'", - row[1], - ) + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + notify_count = row[0] if row else 0 txn.execute( """ - SELECT notif_count, unread_count FROM event_push_summary + SELECT notif_count FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering > ? """, (room_id, user_id, stream_ordering), @@ -176,7 +148,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): rows = txn.fetchall() if rows: notify_count += rows[0][0] - unread_count += rows[0][1] # Now get the number of highlights sql = ( @@ -193,11 +164,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): row = txn.fetchone() highlight_count = row[0] if row else 0 - return { - "unread_count": unread_count, - "notify_count": notify_count, - "highlight_count": highlight_count, - } + return {"notify_count": notify_count, "highlight_count": highlight_count} @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): @@ -255,7 +222,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -284,7 +250,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering ASC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -357,7 +322,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -386,7 +350,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): " AND ep.user_id = ?" " AND ep.stream_ordering > ?" " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" " ORDER BY ep.stream_ordering DESC LIMIT ?" ) args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] @@ -436,7 +399,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _get_if_maybe_push_in_range_for_user_txn(txn): sql = """ SELECT 1 FROM event_push_actions - WHERE user_id = ? AND stream_ordering > ? AND notif = 1 + WHERE user_id = ? AND stream_ordering > ? LIMIT 1 """ @@ -465,15 +428,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): return # This is a helper function for generating the necessary tuple that - # can be used to insert into the `event_push_actions_staging` table. + # can be used to inert into the `event_push_actions_staging` table. def _gen_entry(user_id, actions): is_highlight = 1 if _action_has_highlight(actions) else 0 - notif = 0 if "org.matrix.msc2625.mark_unread" in actions else 1 return ( event_id, # event_id column user_id, # user_id column _serialize_action(actions, is_highlight), # actions column - notif, # notif column + 1, # notif column is_highlight, # highlight column ) @@ -855,51 +817,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, - coalesce(old.%s, 0) + upd.cnt, + coalesce(old.notif_count, 0) + upd.notif_count, upd.stream_ordering, old.user_id FROM ( - SELECT user_id, room_id, count(*) as cnt, + SELECT user_id, room_id, count(*) as notif_count, max(stream_ordering) as stream_ordering FROM event_push_actions WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0 - %s GROUP BY user_id, room_id ) AS upd LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - # First get the count of unread messages. - txn.execute( - sql % ("unread_count", ""), - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - # We need to merge both lists into a single object because we might not have the - # same amount of rows in each of them. In this case we use a dict indexed on the - # user ID and room ID to make it easier to populate. - summaries = {} # type: Dict[Tuple[str, str], EventPushSummary] - for row in txn: - summaries[(row[0], row[1])] = EventPushSummary( - unread_count=row[2], - stream_ordering=row[3], - old_user_id=row[4], - notif_count=0, - ) - - # Then get the count of notifications. - txn.execute( - sql % ("notif_count", "AND notif = 1"), - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - # notif_rows is populated based on a subset of the query used to populate - # unread_rows, so we can be sure that there will be no KeyError here. - for row in txn: - summaries[(row[0], row[1])].notif_count = row[2] + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) + rows = txn.fetchall() - logger.info("Rotating notifications, handling %d rows", len(summaries)) + logger.info("Rotating notifications, handling %d rows", len(rows)) # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the @@ -909,34 +844,22 @@ class EventPushActionsStore(EventPushActionsWorkerStore): table="event_push_summary", values=[ { - "user_id": user_id, - "room_id": room_id, - "notif_count": summary.notif_count, - "unread_count": summary.unread_count, - "stream_ordering": summary.stream_ordering, + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], } - for ((user_id, room_id), summary) in summaries.items() - if summary.old_user_id is None + for row in rows + if row[4] is None ], ) txn.executemany( """ - UPDATE event_push_summary - SET notif_count = ?, unread_count = ?, stream_ordering = ? + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? WHERE user_id = ? AND room_id = ? """, - ( - ( - summary.notif_count, - summary.unread_count, - summary.stream_ordering, - user_id, - room_id, - ) - for ((user_id, room_id), summary) in summaries.items() - if summary.old_user_id is not None - ), + ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), ) txn.execute( diff --git a/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql b/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql deleted file mode 100644 index f1459ef7f0..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store the number of unread messages, i.e. messages that triggered either a notify --- action or a mark_unread one. -ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0; - --- Pre-populate the new column with the count of pending notifications. --- We expect event_push_summary to be relatively small, so we can do this update --- synchronously without impacting Synapse's startup time too much. -UPDATE event_push_summary SET unread_count = notif_count; \ No newline at end of file diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index cd8680e812..1a88c7fb80 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 0, "unread_count": 0}, + {"highlight_count": 0, "notify_count": 0}, ) self.persist( @@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 0, "notify_count": 1, "unread_count": 1}, + {"highlight_count": 0, "notify_count": 1}, ) self.persist( @@ -188,20 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2, "unread_count": 2}, - ) - - self.persist( - type="m.room.message", - msgtype="m.text", - body="world", - push_actions=[(USER_ID_2, ["org.matrix.msc2625.mark_unread"])], - ) - self.replicate() - self.check( - "get_unread_event_push_actions_by_room_for_user", - [ROOM_ID, USER_ID_2, event1.event_id], - {"highlight_count": 1, "notify_count": 2, "unread_count": 3}, + {"highlight_count": 1, "notify_count": 2}, ) def test_get_rooms_for_user_with_stream_ordering(self): diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 303dc8571c..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -22,10 +22,6 @@ import tests.utils USER_ID = "@user:example.com" -MARK_UNREAD = [ - "org.matrix.msc2625.mark_unread", - {"set_tweak": "highlight", "value": False}, -] PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] HIGHLIGHT = [ "notify", @@ -59,17 +55,13 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): user_id = "@user1235:example.com" @defer.inlineCallbacks - def _assert_counts(unread_count, notif_count, highlight_count): + def _assert_counts(noitf_count, highlight_count): counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( counts, - { - "unread_count": unread_count, - "notify_count": notif_count, - "highlight_count": highlight_count, - }, + {"notify_count": noitf_count, "highlight_count": highlight_count}, ) @defer.inlineCallbacks @@ -104,23 +96,23 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): stream, ) - yield _assert_counts(0, 0, 0) + yield _assert_counts(0, 0) yield _inject_actions(1, PlAIN_NOTIF) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _rotate(2) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _inject_actions(3, PlAIN_NOTIF) - yield _assert_counts(2, 2, 0) + yield _assert_counts(2, 0) yield _rotate(4) - yield _assert_counts(2, 2, 0) + yield _assert_counts(2, 0) yield _inject_actions(5, PlAIN_NOTIF) yield _mark_read(3, 3) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _mark_read(5, 5) - yield _assert_counts(0, 0, 0) + yield _assert_counts(0, 0) yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) @@ -129,22 +121,17 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): table="event_push_actions", keyvalues={"1": 1}, desc="" ) - yield _assert_counts(1, 1, 0) + yield _assert_counts(1, 0) yield _mark_read(7, 7) - yield _assert_counts(0, 0, 0) + yield _assert_counts(0, 0) - yield _inject_actions(8, MARK_UNREAD) - yield _assert_counts(1, 0, 0) + yield _inject_actions(8, HIGHLIGHT) + yield _assert_counts(1, 1) yield _rotate(9) - yield _assert_counts(1, 0, 0) - - yield _inject_actions(10, HIGHLIGHT) - yield _assert_counts(2, 1, 1) - yield _rotate(11) - yield _assert_counts(2, 1, 1) - yield _rotate(12) - yield _assert_counts(2, 1, 1) + yield _assert_counts(1, 1) + yield _rotate(10) + yield _assert_counts(1, 1) @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): -- cgit 1.5.1 From e5808c4cfbec60f11f358bea529b321e94751ec9 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Wed, 1 Jul 2020 17:02:31 +0100 Subject: Hack to add push priority to push notifications (#7765) * Remove obsolete comment about ancient temporary code Signed-off-by: Olivier Wilkinson (reivilibre) * Implement hack to set push priority based on whether the tweaks indicate the event might cause effects. * Changelog for 7765 Signed-off-by: Olivier Wilkinson (reivilibre) * Antilint * Add tests for push priority Signed-off-by: Olivier Wilkinson (reivilibre) * Update synapse/push/httppusher.py Co-authored-by: Brendan Abolivier * Antilint * Remove needless invites from tests. Co-authored-by: Brendan Abolivier --- changelog.d/7765.misc | 1 + synapse/push/httppusher.py | 17 ++- tests/push/test_http.py | 352 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 362 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7765.misc (limited to 'synapse/push') diff --git a/changelog.d/7765.misc b/changelog.d/7765.misc new file mode 100644 index 0000000000..fa9cfd24cb --- /dev/null +++ b/changelog.d/7765.misc @@ -0,0 +1 @@ +Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index ed60dbc1bf..2fac07593b 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -20,6 +20,7 @@ from prometheus_client import Counter from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled +from synapse.api.constants import EventTypes from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException @@ -305,12 +306,23 @@ class HttpPusher(object): @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): + priority = "low" + if ( + event.type == EventTypes.Encrypted + or tweaks.get("highlight") + or tweaks.get("sound") + ): + # HACK send our push as high priority only if it generates a sound, highlight + # or may do so (i.e. is encrypted so has unknown effects). + priority = "high" + if self.data.get("format") == "event_id_only": d = { "notification": { "event_id": event.event_id, "room_id": event.room_id, "counts": {"unread": badge}, + "prio": priority, "devices": [ { "app_id": self.app_id, @@ -334,9 +346,8 @@ class HttpPusher(object): "room_id": event.room_id, "type": event.type, "sender": event.user_id, - "counts": { # -- we don't mark messages as read yet so - # we have no way of knowing - # Just set the badge to 1 until we have read receipts + "prio": priority, + "counts": { "unread": badge, # 'missed_calls': 2 }, diff --git a/tests/push/test_http.py b/tests/push/test_http.py index baf9c785f4..b567868b02 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -25,7 +25,6 @@ from tests.unittest import HomeserverTestCase class HTTPPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -35,7 +34,6 @@ class HTTPPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor, clock): - self.push_attempts = [] m = Mock() @@ -90,9 +88,6 @@ class HTTPPusherTests(HomeserverTestCase): # Create a room room = self.helper.create_room_as(user_id, tok=access_token) - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - # The other user joins self.helper.join(room=room, user=other_user_id, tok=other_access_token) @@ -157,3 +152,350 @@ class HTTPPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) + + def test_sends_high_priority_for_encrypted(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to an encrypted message. + This will happen both in 1:1 rooms and larger rooms. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send an encrypted event + # I know there'd normally be set-up of an encrypted room first + # but this will do for our purposes + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "6lImKbzK51MzWLwHh8tUM3UBBSBrLlgup/OOCGTvumM", + "ciphertext": "AwgAErABoRxwpMipdgiwXgu46rHiWQ0DmRj0qUlPrMraBUDk" + "leTnJRljpuc7IOhsYbLY3uo2WI0ab/ob41sV+3JEIhODJPqH" + "TK7cEZaIL+/up9e+dT9VGF5kRTWinzjkeqO8FU5kfdRjm+3w" + "0sy3o1OCpXXCfO+faPhbV/0HuK4ndx1G+myNfK1Nk/CxfMcT" + "BT+zDS/Df/QePAHVbrr9uuGB7fW8ogW/ulnydgZPRluusFGv" + "J3+cg9LoPpZPAmv5Me3ec7NtdlfN0oDZ0gk3TiNkkhsxDG9Y" + "YcNzl78USI0q8+kOV26Bu5dOBpU4WOuojXZHJlP5lMgdzLLl" + "EQ0", + "session_id": "IigqfNWLL+ez/Is+Duwp2s4HuCZhFG9b9CZKTYHtQ4A", + "device_id": "AHQDUSTAAA", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Add yet another person — we want to make this room not a 1:1 + # (as encrypted messages in a 1:1 currently have tweaks applied + # so it doesn't properly exercise the condition of all encrypted + # messages need to be high). + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another encrypted event + self.helper.send_event( + room, + "m.room.encrypted", + content={ + "ciphertext": "AwgAEoABtEuic/2DF6oIpNH+q/PonzlhXOVho8dTv0tzFr5m" + "9vTo50yabx3nxsRlP2WxSqa8I07YftP+EKWCWJvTkg6o7zXq" + "6CK+GVvLQOVgK50SfvjHqJXN+z1VEqj+5mkZVN/cAgJzoxcH" + "zFHkwDPJC8kQs47IHd8EO9KBUK4v6+NQ1uE/BIak4qAf9aS/" + "kI+f0gjn9IY9K6LXlah82A/iRyrIrxkCkE/n0VfvLhaWFecC" + "sAWTcMLoF6fh1Jpke95mljbmFSpsSd/eEQw", + "device_id": "SRCFTWTHXO", + "session_id": "eMA+bhGczuTz1C5cJR1YbmrnnC6Goni4lbvS5vJ1nG4", + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "rC/XSIAiYrVGSuaHMop8/pTZbku4sQKBZwRwukgnN1c", + }, + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") + + def test_sends_high_priority_for_one_to_one_only(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message in a one-to-one room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Hi!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority — this is a one-to-one room + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Yet another user joins + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Check no push notifications are sent regarding the membership changes + # (that would confuse the test) + self.pump() + self.assertEqual(len(self.push_attempts), 1) + + # Send another event + self.helper.send(room, body="Welcome!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_mention(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message containing the user's display name. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other users join + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send(room, body="Oh, user, hello!", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time with no mention + self.helper.send(room, body="Are you there?", tok=other_access_token) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") + + def test_sends_high_priority_for_atroom(self): + """ + The HTTP pusher will send pushes at high priority if they correspond + to a message that contains @room. + """ + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register a third user + yet_another_user_id = self.register_user("yetanotheruser", "pass") + yet_another_access_token = self.login("yetanotheruser", "pass") + + # Create a room (as other_user so the power levels are compatible with + # other_user sending @room). + room = self.helper.create_room_as(other_user_id, tok=other_access_token) + + # The other users join + self.helper.join(room=room, user=user_id, tok=access_token) + self.helper.join( + room=room, user=yet_another_user_id, tok=yet_another_access_token + ) + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Send a message + self.helper.send( + room, + body="@room eeek! There's a spider on the table!", + tok=other_access_token, + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + # Make the push succeed + self.push_attempts[0][0].callback({}) + self.pump() + + # Check our push made it with high priority + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(self.push_attempts[0][1], "example.com") + self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high") + + # Send another event, this time as someone without the power of @room + self.helper.send( + room, body="@room the spider is gone", tok=yet_another_access_token + ) + + # Advance time a bit, so the pusher will register something has happened + self.pump() + self.assertEqual(len(self.push_attempts), 2) + self.assertEqual(self.push_attempts[1][1], "example.com") + + # check that this is low-priority + self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") -- cgit 1.5.1 From 57feeab364325374b14ff67ac97c288983cc5cde Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 6 Jul 2020 11:43:41 +0100 Subject: Don't ignore `set_tweak` actions with no explicit `value`. (#7766) * Fix spec compliance; tweaks without values are valid (default to True, which is only concretely specified for `highlight`, but it seems only reasonable to generalise) * Changelog for 7766. * Add documentation to `tweaks_for_actions` May as well tidy up when I'm here. * Add a test for `tweaks_for_actions` --- changelog.d/7766.bugfix | 1 + synapse/push/push_rule_evaluator.py | 31 +++++++++++++++++++++++++++---- tests/push/test_push_rule_evaluator.py | 17 +++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7766.bugfix (limited to 'synapse/push') diff --git a/changelog.d/7766.bugfix b/changelog.d/7766.bugfix new file mode 100644 index 0000000000..ec5ecd8055 --- /dev/null +++ b/changelog.d/7766.bugfix @@ -0,0 +1 @@ +Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification. diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 8e0d3a416d..2d79ada189 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,7 +16,7 @@ import logging import re -from typing import Pattern +from typing import Any, Dict, List, Pattern, Union from synapse.events import EventBase from synapse.types import UserID @@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number): return False -def tweaks_for_actions(actions): +def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]: + """ + Converts a list of actions into a `tweaks` dict (which can then be passed to + the push gateway). + + This function ignores all actions other than `set_tweak` actions, and treats + absent `value`s as `True`, which agrees with the only spec-defined treatment + of absent `value`s (namely, for `highlight` tweaks). + + Args: + actions: list of actions + e.g. [ + {"set_tweak": "a", "value": "AAA"}, + {"set_tweak": "b", "value": "BBB"}, + {"set_tweak": "highlight"}, + "notify" + ] + + Returns: + dictionary of tweaks for those actions + e.g. {"a": "AAA", "b": "BBB", "highlight": True} + """ tweaks = {} for a in actions: if not isinstance(a, dict): continue - if "set_tweak" in a and "value" in a: - tweaks[a["set_tweak"]] = a["value"] + if "set_tweak" in a: + # value is allowed to be absent in which case the value assumed + # should be True. + tweaks[a["set_tweak"]] = a.get("value", True) return tweaks diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index af35d23aea..1f4b5ca2ac 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -15,6 +15,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent +from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent from tests import unittest @@ -84,3 +85,19 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): for body in (1, True, {"foo": "bar"}): evaluator = self._get_evaluator({"body": body}) self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) + + def test_tweaks_for_actions(self): + """ + This tests the behaviour of tweaks_for_actions. + """ + + actions = [ + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + "notify", + ] + + self.assertEqual( + push_rule_evaluator.tweaks_for_actions(actions), + {"sound": "default", "highlight": True}, + ) -- cgit 1.5.1 From f886a699169e416dca7a8d23d3874dfada24629d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 14 Jul 2020 10:00:53 +0100 Subject: Correctly pass app_name to all email templates. (#7829) We didn't do this for e.g. registration emails. --- changelog.d/7829.bugfix | 1 + synapse/push/mailer.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7829.bugfix (limited to 'synapse/push') diff --git a/changelog.d/7829.bugfix b/changelog.d/7829.bugfix new file mode 100644 index 0000000000..dcbf385de6 --- /dev/null +++ b/changelog.d/7829.bugfix @@ -0,0 +1 @@ +Fix bug where we did not always pass in `app_name` or `server_name` to email templates, including e.g. for registration emails. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index dda560b2c2..a10dba0af6 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -269,7 +269,6 @@ class Mailer(object): user_id, app_id, email_address ), "summary_text": summary_text, - "app_name": self.app_name, "rooms": rooms, "reason": reason, } @@ -278,7 +277,7 @@ class Mailer(object): email_address, "[%s] %s" % (self.app_name, summary_text), template_vars ) - async def send_email(self, email_address, subject, template_vars): + async def send_email(self, email_address, subject, extra_template_vars): """Send an email with the given information and template text""" try: from_string = self.hs.config.email_notif_from % {"app": self.app_name} @@ -291,6 +290,13 @@ class Mailer(object): if raw_to == "": raise RuntimeError("Invalid 'to' address") + template_vars = { + "app_name": self.app_name, + "server_name": self.hs.config.server.server_name, + } + + template_vars.update(extra_template_vars) + html_text = self.template_html.render(**template_vars) html_part = MIMEText(html_text, "html", "utf8") -- cgit 1.5.1 From 85223106f3c04d2aa4747906412ef05435409eec Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 14 Jul 2020 19:10:42 +0100 Subject: Allow email subjects to be customised through Synapse's configuration (#7846) --- changelog.d/7846.feature | 1 + docs/sample_config.yaml | 71 ++++++++++++++++++++++++- synapse/config/emailconfig.py | 118 +++++++++++++++++++++++++++++++++++++++--- synapse/push/mailer.py | 51 +++++++----------- 4 files changed, 202 insertions(+), 39 deletions(-) create mode 100644 changelog.d/7846.feature (limited to 'synapse/push') diff --git a/changelog.d/7846.feature b/changelog.d/7846.feature new file mode 100644 index 0000000000..997376fe42 --- /dev/null +++ b/changelog.d/7846.feature @@ -0,0 +1 @@ +Allow email subjects to be customised through Synapse's configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9d94495464..e059fd2c35 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1949,8 +1949,8 @@ email: # #notif_from: "Your Friendly %(app)s homeserver " - # app_name defines the default value for '%(app)s' in notif_from. It - # defaults to 'Matrix'. + # app_name defines the default value for '%(app)s' in notif_from and email + # subjects. It defaults to 'Matrix'. # #app_name: my_branded_matrix_server @@ -2019,6 +2019,73 @@ email: # #template_dir: "res/templates" + # Subjects to use when sending emails from Synapse. + # + # The placeholder '%(app)s' will be replaced with the value of the 'app_name' + # setting above, or by a value dictated by the Matrix client application. + # + # If a subject isn't overridden in this configuration file, the value used as + # its example will be used. + # + #subjects: + + # Subjects for notification emails. + # + # On top of the '%(app)s' placeholder, these can use the following + # placeholders: + # + # * '%(person)s', which will be replaced by the display name of the user(s) + # that sent the message(s), e.g. "Alice and Bob". + # * '%(room)s', which will be replaced by the name of the room the + # message(s) have been sent to, e.g. "My super room". + # + # See the example provided for each setting to see which placeholder can be + # used and how to use them. + # + # Subject to use to notify about one message from one or more user(s) in a + # room which has a name. + #message_from_person_in_room: "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room..." + # + # Subject to use to notify about one message from one or more user(s) in a + # room which doesn't have a name. + #message_from_person: "[%(app)s] You have a message on %(app)s from %(person)s..." + # + # Subject to use to notify about multiple messages from one or more users in + # a room which doesn't have a name. + #messages_from_person: "[%(app)s] You have messages on %(app)s from %(person)s..." + # + # Subject to use to notify about multiple messages in a room which has a + # name. + #messages_in_room: "[%(app)s] You have messages on %(app)s in the %(room)s room..." + # + # Subject to use to notify about multiple messages in multiple rooms. + #messages_in_room_and_others: "[%(app)s] You have messages on %(app)s in the %(room)s room and others..." + # + # Subject to use to notify about multiple messages from multiple persons in + # multiple rooms. This is similar to the setting above except it's used when + # the room in which the notification was triggered has no name. + #messages_from_person_and_others: "[%(app)s] You have messages on %(app)s from %(person)s and others..." + # + # Subject to use to notify about an invite to a room which has a name. + #invite_from_person_to_room: "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s..." + # + # Subject to use to notify about an invite to a room which doesn't have a + # name. + #invite_from_person: "[%(app)s] %(person)s has invited you to chat on %(app)s..." + + # Subject for emails related to account administration. + # + # On top of the '%(app)s' placeholder, these one can use the + # '%(server_name)s' placeholder, which will be replaced by the value of the + # 'server_name' setting in your Synapse configuration. + # + # Subject to use when sending a password reset email. + #password_reset: "[%(server_name)s] Password reset" + # + # Subject to use when sending a verification email to assert an address's + # ownership. + #email_validation: "[%(server_name)s] Validate your email" + # Password providers allow homeserver administrators to integrate # their Synapse installation with existing authentication methods diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index b1dc7ad502..a63acbdc63 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -22,6 +22,7 @@ import os from enum import Enum from typing import Optional +import attr import pkg_resources from ._base import Config, ConfigError @@ -32,6 +33,33 @@ Password reset emails are enabled on this homeserver due to a partial %s """ +DEFAULT_SUBJECTS = { + "message_from_person_in_room": "[%(app)s] You have a message on %(app)s from %(person)s in the %(room)s room...", + "message_from_person": "[%(app)s] You have a message on %(app)s from %(person)s...", + "messages_from_person": "[%(app)s] You have messages on %(app)s from %(person)s...", + "messages_in_room": "[%(app)s] You have messages on %(app)s in the %(room)s room...", + "messages_in_room_and_others": "[%(app)s] You have messages on %(app)s in the %(room)s room and others...", + "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...", + "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...", + "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...", + "password_reset": "[%(server_name)s] Password reset", + "email_validation": "[%(server_name)s] Validate your email", +} + + +@attr.s +class EmailSubjectConfig: + message_from_person_in_room = attr.ib(type=str) + message_from_person = attr.ib(type=str) + messages_from_person = attr.ib(type=str) + messages_in_room = attr.ib(type=str) + messages_in_room_and_others = attr.ib(type=str) + messages_from_person_and_others = attr.ib(type=str) + invite_from_person = attr.ib(type=str) + invite_from_person_to_room = attr.ib(type=str) + password_reset = attr.ib(type=str) + email_validation = attr.ib(type=str) + class EmailConfig(Config): section = "email" @@ -294,8 +322,17 @@ class EmailConfig(Config): if not os.path.isfile(p): raise ConfigError("Unable to find email template file %s" % (p,)) + subjects_config = email_config.get("subjects", {}) + subjects = {} + + for key, default in DEFAULT_SUBJECTS.items(): + subjects[key] = subjects_config.get(key, default) + + self.email_subjects = EmailSubjectConfig(**subjects) + def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """\ + return ( + """\ # Configuration for sending emails from Synapse. # email: @@ -323,17 +360,17 @@ class EmailConfig(Config): # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # - # The placeholder '%(app)s' will be replaced by the application name, + # The placeholder '%%(app)s' will be replaced by the application name, # which is normally 'app_name' (below), but may be overridden by the # Matrix client application. # - # Note that the placeholder must be written '%(app)s', including the + # Note that the placeholder must be written '%%(app)s', including the # trailing 's'. # - #notif_from: "Your Friendly %(app)s homeserver " + #notif_from: "Your Friendly %%(app)s homeserver " - # app_name defines the default value for '%(app)s' in notif_from. It - # defaults to 'Matrix'. + # app_name defines the default value for '%%(app)s' in notif_from and email + # subjects. It defaults to 'Matrix'. # #app_name: my_branded_matrix_server @@ -401,7 +438,76 @@ class EmailConfig(Config): # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates # #template_dir: "res/templates" + + # Subjects to use when sending emails from Synapse. + # + # The placeholder '%%(app)s' will be replaced with the value of the 'app_name' + # setting above, or by a value dictated by the Matrix client application. + # + # If a subject isn't overridden in this configuration file, the value used as + # its example will be used. + # + #subjects: + + # Subjects for notification emails. + # + # On top of the '%%(app)s' placeholder, these can use the following + # placeholders: + # + # * '%%(person)s', which will be replaced by the display name of the user(s) + # that sent the message(s), e.g. "Alice and Bob". + # * '%%(room)s', which will be replaced by the name of the room the + # message(s) have been sent to, e.g. "My super room". + # + # See the example provided for each setting to see which placeholder can be + # used and how to use them. + # + # Subject to use to notify about one message from one or more user(s) in a + # room which has a name. + #message_from_person_in_room: "%(message_from_person_in_room)s" + # + # Subject to use to notify about one message from one or more user(s) in a + # room which doesn't have a name. + #message_from_person: "%(message_from_person)s" + # + # Subject to use to notify about multiple messages from one or more users in + # a room which doesn't have a name. + #messages_from_person: "%(messages_from_person)s" + # + # Subject to use to notify about multiple messages in a room which has a + # name. + #messages_in_room: "%(messages_in_room)s" + # + # Subject to use to notify about multiple messages in multiple rooms. + #messages_in_room_and_others: "%(messages_in_room_and_others)s" + # + # Subject to use to notify about multiple messages from multiple persons in + # multiple rooms. This is similar to the setting above except it's used when + # the room in which the notification was triggered has no name. + #messages_from_person_and_others: "%(messages_from_person_and_others)s" + # + # Subject to use to notify about an invite to a room which has a name. + #invite_from_person_to_room: "%(invite_from_person_to_room)s" + # + # Subject to use to notify about an invite to a room which doesn't have a + # name. + #invite_from_person: "%(invite_from_person)s" + + # Subject for emails related to account administration. + # + # On top of the '%%(app)s' placeholder, these one can use the + # '%%(server_name)s' placeholder, which will be replaced by the value of the + # 'server_name' setting in your Synapse configuration. + # + # Subject to use when sending a password reset email. + #password_reset: "%(password_reset)s" + # + # Subject to use when sending a verification email to assert an address's + # ownership. + #email_validation: "%(email_validation)s" """ + % DEFAULT_SUBJECTS + ) class ThreepidBehaviour(Enum): diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index a10dba0af6..af117fddf9 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -27,6 +27,7 @@ import jinja2 from synapse.api.constants import EventTypes from synapse.api.errors import StoreError +from synapse.config.emailconfig import EmailSubjectConfig from synapse.logging.context import make_deferred_yieldable from synapse.push.presentable_names import ( calculate_room_name, @@ -42,23 +43,6 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -MESSAGE_FROM_PERSON_IN_ROOM = ( - "You have a message on %(app)s from %(person)s in the %(room)s room..." -) -MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." -MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." -MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..." -MESSAGES_IN_ROOM_AND_OTHERS = ( - "You have messages on %(app)s in the %(room)s room and others..." -) -MESSAGES_FROM_PERSON_AND_OTHERS = ( - "You have messages on %(app)s from %(person)s and others..." -) -INVITE_FROM_PERSON_TO_ROOM = ( - "%(person)s has invited you to join the %(room)s room on %(app)s..." -) -INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." - CONTEXT_BEFORE = 1 CONTEXT_AFTER = 1 @@ -121,6 +105,7 @@ class Mailer(object): self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() self.app_name = app_name + self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig logger.info("Created Mailer for app_name %s" % app_name) @@ -147,7 +132,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Password Reset" % self.hs.config.server_name, + self.email_subjects.password_reset + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -174,7 +160,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Register your Email Address" % self.hs.config.server_name, + self.email_subjects.email_validation + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -202,7 +189,8 @@ class Mailer(object): await self.send_email( email_address, - "[%s] Validate Your Email" % self.hs.config.server_name, + self.email_subjects.email_validation + % {"server_name": self.hs.config.server_name}, template_vars, ) @@ -273,9 +261,7 @@ class Mailer(object): "reason": reason, } - await self.send_email( - email_address, "[%s] %s" % (self.app_name, summary_text), template_vars - ) + await self.send_email(email_address, summary_text, template_vars) async def send_email(self, email_address, subject, extra_template_vars): """Send an email with the given information and template text""" @@ -482,12 +468,12 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - return INVITE_FROM_PERSON % { + return self.email_subjects.invite_from_person % { "person": inviter_name, "app": self.app_name, } else: - return INVITE_FROM_PERSON_TO_ROOM % { + return self.email_subjects.invite_from_person_to_room % { "person": inviter_name, "room": room_name, "app": self.app_name, @@ -505,13 +491,13 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - return MESSAGE_FROM_PERSON_IN_ROOM % { + return self.email_subjects.message_from_person_in_room % { "person": sender_name, "room": room_name, "app": self.app_name, } elif sender_name is not None: - return MESSAGE_FROM_PERSON % { + return self.email_subjects.message_from_person % { "person": sender_name, "app": self.app_name, } @@ -519,7 +505,10 @@ class Mailer(object): # There's more than one notification for this room, so just # say there are several if room_name is not None: - return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name} + return self.email_subjects.messages_in_room % { + "room": room_name, + "app": self.app_name, + } else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -537,7 +526,7 @@ class Mailer(object): ] ) - return MESSAGES_FROM_PERSON % { + return self.email_subjects.messages_from_person % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, } @@ -546,7 +535,7 @@ class Mailer(object): # ...but we still refer to the 'reason' room which triggered the mail if reason["room_name"] is not None: - return MESSAGES_IN_ROOM_AND_OTHERS % { + return self.email_subjects.messages_in_room_and_others % { "room": reason["room_name"], "app": self.app_name, } @@ -566,7 +555,7 @@ class Mailer(object): [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids] ) - return MESSAGES_FROM_PERSON_AND_OTHERS % { + return self.email_subjects.messages_from_person_and_others % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, } -- cgit 1.5.1 From 649a7ead5c4bd2d8b7c486ac1a68ce4e41d49767 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jul 2020 14:06:28 +0100 Subject: Add ability to run multiple pusher instances (#7855) This reuses the same scheme as federation sender sharding --- changelog.d/7855.feature | 1 + synapse/config/_base.py | 38 +++- synapse/config/_base.pyi | 5 + synapse/config/federation.py | 37 +--- synapse/config/push.py | 5 +- synapse/federation/sender/__init__.py | 16 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/push/pusherpool.py | 78 +++++---- tests/replication/test_pusher_shard.py | 193 +++++++++++++++++++++ 9 files changed, 293 insertions(+), 82 deletions(-) create mode 100644 changelog.d/7855.feature create mode 100644 tests/replication/test_pusher_shard.py (limited to 'synapse/push') diff --git a/changelog.d/7855.feature b/changelog.d/7855.feature new file mode 100644 index 0000000000..2b6a9f0e71 --- /dev/null +++ b/changelog.d/7855.feature @@ -0,0 +1 @@ +Add experimental support for running multiple pusher workers. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1391e5fc43..fd137853b1 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -19,9 +19,11 @@ import argparse import errno import os from collections import OrderedDict +from hashlib import sha256 from textwrap import dedent -from typing import Any, MutableMapping, Optional +from typing import Any, List, MutableMapping, Optional +import attr import yaml @@ -717,4 +719,36 @@ def find_config_files(search_paths): return config_files -__all__ = ["Config", "RootConfig"] +@attr.s +class ShardedWorkerHandlingConfig: + """Algorithm for choosing which instance is responsible for handling some + sharded work. + + For example, the federation senders use this to determine which instances + handles sending stuff to a given destination (which is used as the `key` + below). + """ + + instances = attr.ib(type=List[str]) + + def should_handle(self, instance_name: str, key: str) -> bool: + """Whether this instance is responsible for handling the given key. + """ + + # If multiple instances are not defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash, modulo it by the number of instances and + # then checking whether this instance matches the instance at that + # index. + # + # (Technically this introduces some bias and is not entirely uniform, + # but since the hash is so large the bias is ridiculously small). + dest_hash = sha256(key.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 9e576060d4..eb911e8f9f 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -137,3 +137,8 @@ class Config: def read_config_files(config_files: List[str]): ... def find_config_files(search_paths: List[str]): ... + +class ShardedWorkerHandlingConfig: + instances: List[str] + def __init__(self, instances: List[str]) -> None: ... + def should_handle(self, instance_name: str, key: str) -> bool: ... diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 7782ab4c9d..82ff9664de 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -13,42 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hashlib import sha256 -from typing import List, Optional +from typing import Optional -import attr from netaddr import IPSet -from ._base import Config, ConfigError - - -@attr.s -class ShardedFederationSendingConfig: - """Algorithm for choosing which federation sender instance is responsible - for which destionation host. - """ - - instances = attr.ib(type=List[str]) - - def should_send_to(self, instance_name: str, destination: str) -> bool: - """Whether this instance is responsible for sending transcations for - the given host. - """ - - # If multiple federation senders are not defined we always return true. - if not self.instances or len(self.instances) == 1: - return True - - # We shard by taking the hash, modulo it by the number of federation - # senders and then checking whether this instance matches the instance - # at that index. - # - # (Technically this introduces some bias and is not entirely uniform, but - # since the hash is so large the bias is ridiculously small). - dest_hash = sha256(destination.encode("utf8")).digest() - dest_int = int.from_bytes(dest_hash, byteorder="little") - remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name +from ._base import Config, ConfigError, ShardedWorkerHandlingConfig class FederationConfig(Config): @@ -61,7 +30,7 @@ class FederationConfig(Config): self.send_federation = config.get("send_federation", True) federation_sender_instances = config.get("federation_sender_instances") or [] - self.federation_shard_config = ShardedFederationSendingConfig( + self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances ) diff --git a/synapse/config/push.py b/synapse/config/push.py index 6f2b3a7faa..a1f3752c8a 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config +from ._base import Config, ShardedWorkerHandlingConfig class PushConfig(Config): @@ -24,6 +24,9 @@ class PushConfig(Config): push_config = config.get("push", {}) self.push_include_content = push_config.get("include_content", True) + pusher_instances = config.get("pusher_instances") or [] + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) + # There was a a 'redact_content' setting but mistakenly read from the # 'email'section'. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4b63a0755f..b328a4df09 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -197,7 +197,7 @@ class FederationSender(object): destinations = { d for d in destinations - if self._federation_shard_config.should_send_to( + if self._federation_shard_config.should_handle( self._instance_name, d ) } @@ -335,7 +335,7 @@ class FederationSender(object): d for d in domains if d != self.server_name - and self._federation_shard_config.should_send_to(self._instance_name, d) + and self._federation_shard_config.should_handle(self._instance_name, d) ] if not domains: return @@ -441,7 +441,7 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): continue @@ -460,7 +460,7 @@ class FederationSender(object): if destination == self.server_name: continue - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): continue @@ -486,7 +486,7 @@ class FederationSender(object): logger.info("Not sending EDU to ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return @@ -507,7 +507,7 @@ class FederationSender(object): edu: edu to send key: clobbering key for this edu """ - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, edu.destination ): return @@ -523,7 +523,7 @@ class FederationSender(object): logger.warning("Not sending device update to ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return @@ -541,7 +541,7 @@ class FederationSender(object): logger.warning("Not waking up ourselves") return - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): return diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 6402136e8a..3436741783 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -78,7 +78,7 @@ class PerDestinationQueue(object): self._federation_shard_config = hs.config.federation.federation_shard_config self._should_send_on_this_instance = True - if not self._federation_shard_config.should_send_to( + if not self._federation_shard_config.should_handle( self._instance_name, destination ): # We don't raise an exception here to avoid taking out any other diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index f6a5458681..2456f12f46 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -15,13 +15,12 @@ # limitations under the License. import logging -from collections import defaultdict -from threading import Lock -from typing import Dict, Tuple, Union +from typing import TYPE_CHECKING, Dict, Union + +from prometheus_client import Gauge from twisted.internet import defer -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher from synapse.push.pusher import PusherFactory from synapse.util.async_helpers import concurrently_execute +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) +synapse_pushers = Gauge( + "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"] +) + + class PusherPool: """ The pusher pool. This is responsible for dispatching notifications of new events to @@ -47,36 +55,20 @@ class PusherPool: Pusher.on_new_receipts are not expected to return deferreds. """ - def __init__(self, _hs): - self.hs = _hs - self.pusher_factory = PusherFactory(_hs) - self._should_start_pushers = _hs.config.start_pushers + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.pusher_factory = PusherFactory(hs) + self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + # We shard the handling of push notifications by user ID. + self._pusher_shard_config = hs.config.push.pusher_shard_config + self._instance_name = hs.get_instance_name() + # map from user id to app_id:pushkey to pusher self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]] - # a lock for the pushers dict, since `count_pushers` is called from an different - # and we otherwise get concurrent modification errors - self._pushers_lock = Lock() - - def count_pushers(): - results = defaultdict(int) # type: Dict[Tuple[str, str], int] - with self._pushers_lock: - for pushers in self.pushers.values(): - for pusher in pushers.values(): - k = (type(pusher).__name__, pusher.app_id) - results[k] += 1 - return results - - LaterGauge( - name="synapse_pushers", - desc="the number of active pushers", - labels=["kind", "app_id"], - caller=count_pushers, - ) - def start(self): """Starts the pushers off in a background process. """ @@ -104,6 +96,7 @@ class PusherPool: Returns: Deferred[EmailPusher|HttpPusher] """ + time_now_msec = self.clock.time_msec() # we try to create the pusher just to validate the config: it @@ -176,6 +169,9 @@ class PusherPool: access_tokens (Iterable[int]): access token *ids* to remove pushers for """ + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + tokens = set(access_tokens) for p in (yield self.store.get_pushers_by_user_id(user_id)): if p["access_token"] in tokens: @@ -237,6 +233,9 @@ class PusherPool: if not self._should_start_pushers: return + if not self._pusher_shard_config.should_handle(self._instance_name, user_id): + return + resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None @@ -275,6 +274,11 @@ class PusherPool: Returns: Deferred[EmailPusher|HttpPusher] """ + if not self._pusher_shard_config.should_handle( + self._instance_name, pusherdict["user_name"] + ): + return + try: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: @@ -298,11 +302,12 @@ class PusherPool: appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) - with self._pushers_lock: - byuser = self.pushers.setdefault(pusherdict["user_name"], {}) - if appid_pushkey in byuser: - byuser[appid_pushkey].on_stop() - byuser[appid_pushkey] = p + byuser = self.pushers.setdefault(pusherdict["user_name"], {}) + if appid_pushkey in byuser: + byuser[appid_pushkey].on_stop() + byuser[appid_pushkey] = p + + synapse_pushers.labels(type(p).__name__, p.app_id).inc() # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to @@ -330,9 +335,10 @@ class PusherPool: if appid_pushkey in byuser: logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - byuser[appid_pushkey].on_stop() - with self._pushers_lock: - del byuser[appid_pushkey] + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() yield self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py new file mode 100644 index 0000000000..2bdc6edbb1 --- /dev/null +++ b/tests/replication/test_pusher_shard.py @@ -0,0 +1,193 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + + +class PusherShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks pusher sharding works + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def default_config(self): + conf = super().default_config() + conf["start_pushers"] = False + return conf + + def _create_pusher_and_send_msg(self, localpart): + # Create a user that will get push notifications + user_id = self.register_user(localpart, "pass") + access_token = self.login(localpart, "pass") + + # Register a pusher + user_dict = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_dict["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "https://push.example.com/push"}, + ) + ) + + self.pump() + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + response = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = response["event_id"] + + return event_id + + def test_send_push_single_worker(self): + """Test that registration works when using a pusher worker. + """ + http_client_mock = Mock(spec_set=["post_json_get_json"]) + http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + {"start_pushers": True}, + proxied_http_client=http_client_mock, + ) + + event_id = self._create_pusher_and_send_msg("user") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + def test_send_push_multiple_workers(self): + """Test that registration works when using sharded pusher workers. + """ + http_client_mock1 = Mock(spec_set=["post_json_get_json"]) + http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher1", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock1, + ) + + http_client_mock2 = Mock(spec_set=["post_json_get_json"]) + http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed( + {} + ) + + self.make_worker_hs( + "synapse.app.pusher", + { + "start_pushers": True, + "worker_name": "pusher2", + "pusher_instances": ["pusher1", "pusher2"], + }, + proxied_http_client=http_client_mock2, + ) + + # We choose a user name that we know should go to pusher1. + event_id = self._create_pusher_and_send_msg("user2") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_called_once() + http_client_mock2.post_json_get_json.assert_not_called() + self.assertEqual( + http_client_mock1.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock1.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) + + http_client_mock1.post_json_get_json.reset_mock() + http_client_mock2.post_json_get_json.reset_mock() + + # Now we choose a user name that we know should go to pusher2. + event_id = self._create_pusher_and_send_msg("user4") + + # Advance time a bit, so the pusher will register something has happened + self.pump() + + http_client_mock1.post_json_get_json.assert_not_called() + http_client_mock2.post_json_get_json.assert_called_once() + self.assertEqual( + http_client_mock2.post_json_get_json.call_args[0][0], + "https://push.example.com/push", + ) + self.assertEqual( + event_id, + http_client_mock2.post_json_get_json.call_args[0][1]["notification"][ + "event_id" + ], + ) -- cgit 1.5.1 From b975fa2e9952f1f8ac2cddb15c287768bf9b0b4e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 10:59:51 -0400 Subject: Convert state resolution to async/await (#7942) --- changelog.d/7942.misc | 1 + synapse/api/auth.py | 12 ++- synapse/events/builder.py | 4 +- synapse/federation/sender/__init__.py | 4 +- synapse/handlers/presence.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/state/__init__.py | 95 ++++++++---------- synapse/state/v1.py | 15 ++- synapse/state/v2.py | 107 ++++++++++----------- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/storage/data_stores/main/user_directory.py | 4 +- synapse/storage/persist_events.py | 5 +- tests/federation/test_federation_sender.py | 19 ++-- tests/state/test_v2.py | 17 ++-- tests/storage/test_room.py | 8 +- tests/test_state.py | 72 ++++++++------ tests/test_utils/__init__.py | 7 +- 18 files changed, 198 insertions(+), 184 deletions(-) create mode 100644 changelog.d/7942.misc (limited to 'synapse/push') diff --git a/changelog.d/7942.misc b/changelog.d/7942.misc new file mode 100644 index 0000000000..b504cf4e6f --- /dev/null +++ b/changelog.d/7942.misc @@ -0,0 +1 @@ +Convert state resolution to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 40dc62ef6c..b53e8451e5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -127,8 +127,10 @@ class Auth(object): if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id + member = yield defer.ensureDeferred( + self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id + ) ) membership = member.membership if member else None @@ -665,8 +667,10 @@ class Auth(object): ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" + visibility = yield defer.ensureDeferred( + self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) ) if ( visibility diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 92aadfe7ef..0bb216419a 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,8 +106,8 @@ class EventBuilder(object): Deferred[FrozenEvent] """ - state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids + state_ids = yield defer.ensureDeferred( + self._state.get_current_state_ids(self.room_id, prev_event_ids) ) auth_ids = yield self._auth.compute_auth_events(self, state_ids) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 99ce73e081..ba4ddd2370 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -330,7 +330,9 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield self.state.get_current_hosts_in_room(room_id) + domains = yield defer.ensureDeferred( + self.state.get_current_hosts_in_room(room_id) + ) domains = [ d for d in domains diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8e99c83d9d..b3a3bb8c3f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. - user_ids = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + users = await self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, users)) states_d = await self.current_state_for_users(user_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 43ffe6faf0..472ddf9f7d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -304,7 +304,9 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred( + context.get_current_state_ids() + ) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 495d9f04c8..25ccef5aa5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,14 +16,12 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Set +from typing import Awaitable, Dict, Iterable, List, Optional, Set import attr from frozendict import frozendict from prometheus_client import Histogram -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase @@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.roommember import ProfileInfo from synapse.types import StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -108,8 +107,7 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def get_current_state( + async def get_current_state( self, room_id, event_type=None, state_key="", latest_event_ids=None ): """ Retrieves the current state for the room. This is done by @@ -126,20 +124,20 @@ class StateHandler(object): map from (type, state_key) to event """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: event_id = state.get((event_type, state_key)) event = None if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) return event - state_map = yield self.store.get_events( + state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) state = { @@ -148,8 +146,7 @@ class StateHandler(object): return state - @defer.inlineCallbacks - def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids(self, room_id, latest_event_ids=None): """Get the current state, or the state at a set of events, for a room Args: @@ -164,41 +161,38 @@ class StateHandler(object): (event_type, state_key) -> event_id """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state return state - @defer.inlineCallbacks - def get_current_users_in_room(self, room_id, latest_event_ids=None): + async def get_current_users_in_room( + self, room_id: str, latest_event_ids: Optional[List[str]] = None + ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. Args: - room_id (str): The ID of the room. - latest_event_ids (List[str]|None): Precomputed list of latest - event IDs. Will be computed if None. + room_id: The ID of the room. + latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their - profileinfo. + Dictionary of user IDs to their profileinfo. """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_users_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + joined_users = await self.store.get_joined_users_from_state(room_id, entry) return joined_users - @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id): - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + async def get_current_hosts_in_room(self, room_id): + event_ids = await self.store.get_latest_event_ids_in_room(room_id) + return await self.get_hosts_in_room_at_events(room_id, event_ids) - @defer.inlineCallbacks - def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events(self, room_id, event_ids): """Get the hosts that were in a room at the given event ids Args: @@ -208,12 +202,11 @@ class StateHandler(object): Returns: Deferred[list[str]]: the hosts in the room at the given events """ - entry = yield self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = yield self.store.get_joined_hosts(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, event_ids) + joined_hosts = await self.store.get_joined_hosts(room_id, entry) return joined_hosts - @defer.inlineCallbacks - def compute_event_context( + async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None ): """Build an EventContext structure for the event. @@ -278,7 +271,7 @@ class StateHandler(object): # otherwise, we'll need to resolve the state across the prev_events. logger.debug("calling resolve_state_groups from compute_event_context") - entry = yield self.resolve_state_groups_for_events( + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -295,7 +288,7 @@ class StateHandler(object): # if not state_group_before_event: - state_group_before_event = yield self.state_store.store_state_group( + state_group_before_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -335,7 +328,7 @@ class StateHandler(object): state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = yield self.state_store.store_state_group( + state_group_after_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event, @@ -353,8 +346,7 @@ class StateHandler(object): ) @measure_func() - @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -373,7 +365,7 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.state_store.get_state_groups_ids( + state_groups_ids = await self.state_store.get_state_groups_ids( room_id, event_ids ) @@ -382,7 +374,7 @@ class StateHandler(object): elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) + prev_group, delta_ids = await self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -391,9 +383,9 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) - result = yield self._state_resolution_handler.resolve_state_groups( + result = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, @@ -402,8 +394,7 @@ class StateHandler(object): ) return result - @defer.inlineCallbacks - def resolve_events(self, room_version, state_sets, event): + async def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,7 +405,7 @@ class StateHandler(object): state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, event.room_id, room_version, @@ -451,9 +442,8 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - @defer.inlineCallbacks @log_function - def resolve_state_groups( + async def resolve_state_groups( self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -479,13 +469,13 @@ class StateResolutionHandler(object): state_res_store (StateResolutionStore) Returns: - Deferred[_StateCacheEntry]: resolved state + _StateCacheEntry: resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) - with (yield self.resolve_linearizer.queue(group_names)): + with (await self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: @@ -517,7 +507,7 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, room_id, room_version, @@ -598,7 +588,7 @@ def resolve_events_with_store( state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", -): +) -> Awaitable[StateMap[str]]: """ Args: room_id: the room we are working in @@ -619,8 +609,7 @@ def resolve_events_with_store( state_res_store: a place to fetch events from Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 7b531a8337..ab5e24841d 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,9 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional - -from twisted.internet import defer +from typing import Awaitable, Callable, Dict, List, Optional from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,12 +30,11 @@ logger = logging.getLogger(__name__) POWER_KEY = (EventTypes.PowerLevels, "") -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( room_id: str, state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable, + state_map_factory: Callable[[List[str]], Awaitable], ): """ Args: @@ -56,7 +53,7 @@ def resolve_events_with_store( state_map_factory: will be called with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + an Awaitable that resolves to a dict of event_id to event. Returns: Deferred[dict[(str, str), str]]: @@ -80,7 +77,7 @@ def resolve_events_with_store( # dict[str, FrozenEvent]: a map from state event id to event. Only includes # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) + state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -110,7 +107,7 @@ def resolve_events_with_store( "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) - state_map_new = yield state_map_factory(new_needed_events) + state_map_new = await state_map_factory(new_needed_events) for event in state_map_new.values(): if event.room_id != room_id: raise Exception( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index bf6caa0946..6634955cdc 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Dict, List, Optional -from twisted.internet import defer - import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,14 +30,13 @@ from synapse.util import Clock logger = logging.getLogger(__name__) -# We want to yield to the reactor occasionally during state res when dealing +# We want to await to the reactor occasionally during state res when dealing # with large data sets, so that we don't exhaust the reactor. This is done by -# yielding to reactor during loops every N iterations. -_YIELD_AFTER_ITERATIONS = 100 +# awaiting to reactor during loops every N iterations. +_AWAIT_AFTER_ITERATIONS = 100 -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, @@ -87,7 +84,7 @@ def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) full_conflicted_set = set( itertools.chain( @@ -95,7 +92,7 @@ def resolve_events_with_store( ) ) - events = yield state_res_store.get_events( + events = await state_res_store.get_events( [eid for eid in full_conflicted_set if eid not in event_map], allow_rejected=True, ) @@ -118,14 +115,14 @@ def resolve_events_with_store( eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) - sorted_power_events = yield _reverse_topological_power_sort( + sorted_power_events = await _reverse_topological_power_sort( clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -148,13 +145,13 @@ def resolve_events_with_store( logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) - leftover_events = yield _mainline_sort( + leftover_events = await _mainline_sort( clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -174,8 +171,7 @@ def resolve_events_with_store( return resolved_state -@defer.inlineCallbacks -def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. @@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(room_id, event_id, event_map, state_res_store) + event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): @@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): return int(level) -@defer.inlineCallbacks -def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference(state_sets, event_map, state_res_store): """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. @@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Deferred[set[str]]: Set of event IDs """ - difference = yield state_res_store.get_auth_chain_difference( + difference = await state_res_store.get_auth_chain_difference( [set(state_set.values()) for state_set in state_sets] ) @@ -292,8 +287,7 @@ def _is_power_event(event): return False -@defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph( +async def _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event @@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(room_id, eid, event_map, state_res_store) + event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph( graph.setdefault(eid, set()).add(aid) -@defer.inlineCallbacks -def _reverse_topological_power_sort( +async def _reverse_topological_power_sort( clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, @@ -344,26 +337,26 @@ def _reverse_topological_power_sort( graph = {} for idx, event_id in enumerate(event_ids, start=1): - yield _add_event_and_auth_chain_to_graph( + await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_to_pl = {} for idx, event_id in enumerate(graph, start=1): - pl = yield _get_power_level_for_sender( + pl = await _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) def _get_power_order(event_id): ev = event_map[event_id] @@ -378,8 +371,7 @@ def _reverse_topological_power_sort( return sorted_events -@defer.inlineCallbacks -def _iterative_auth_checks( +async def _iterative_auth_checks( clock, room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the @@ -405,7 +397,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) @@ -420,7 +412,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(room_id, ev_id, event_map, state_res_store) + ev = await _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -438,16 +430,15 @@ def _iterative_auth_checks( except AuthError: pass - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) return resolved_state -@defer.inlineCallbacks -def _mainline_sort( +async def _mainline_sort( clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on @@ -474,21 +465,21 @@ def _mainline_sort( idx = 0 while pl: mainline.append(pl) - pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) + pl_ev = await _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) idx += 1 @@ -498,23 +489,24 @@ def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): - depth = yield _get_mainline_depth_for_event( + depth = await _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids -@defer.inlineCallbacks -def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): +async def _get_mainline_depth_for_event( + event, mainline_map, event_map, state_res_store +): """Get the mainline depths for the given event based on the mainline map Args: @@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor return 0 -@defer.inlineCallbacks -def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store @@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: - events = yield state_res_store.get_events([event_id], allow_rejected=True) + events = await state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) event = event_map.get(event_id) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index d181488db7..c229248101 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -259,7 +259,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 29765890ee..a92e401e88 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 6b8130bf0f..942e51fd3a 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = yield state.get_current_users_in_room(room_id) + users_with_profile = yield defer.ensureDeferred( + state.get_current_users_in_room(room_id) + ) user_ids = set(users_with_profile) # Update each user in the user directory. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa46041676..78fbdcdee8 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,7 +29,6 @@ from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap @@ -648,6 +647,10 @@ class EventsPersistenceStorage(object): room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 1a9bd5f37d..d1bd18da39 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -26,21 +26,24 @@ from synapse.rest import admin from synapse.rest.client.v1 import login from synapse.types import JsonDict, ReadReceipt +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): + mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) + # Ensure a new Awaitable is created for each call. + mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + ["test", "host2"] + ) return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), + state_handler=mock_state_handler, federation_transport_client=Mock(spec=["send_transaction"]), ) @override_config({"send_federation": True}) def test_send_receipts(self): - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out get_current_hosts_in_room - mock_state_handler = hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - # stub out get_users_who_share_room_with_user so that it claims that # `@user2:host2` is in the room def get_users_who_share_room_with_user(user_id): diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 38f9b423ef..f2955a9c69 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +from typing import List import attr @@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - state_before = self.successResultOf(state_d) + state_before = self.successResultOf(defer.ensureDeferred(state_d)) state_after = dict(state_before) if fake_event.state_key is not None: @@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(self.event_map), ) - state = self.successResultOf(state_d) + state = self.successResultOf(defer.ensureDeferred(state_d)) self.assert_dict(self.expected_combined_state, state) @@ -608,9 +609,11 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + return defer.succeed( + {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + ) - def _get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids: List[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -622,10 +625,10 @@ class TestStateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth + event_ids: The event IDs of the events to fetch the auth chain for. Must be state events. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + List of event IDs of the auth chain. """ # Simple DFS for auth chain @@ -648,4 +651,4 @@ class TestStateResolutionStore(object): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return set(chains[0]).union(*chains[1:]) - common + return defer.succeed(set(chains[0]).union(*chains[1:]) - common) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index b1dceb2918..1d77b4a2d6 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( @@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( diff --git a/tests/test_state.py b/tests/test_state.py index 66f22f6813..4858e8fc59 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -97,17 +97,19 @@ class StateGroupStore(object): self._group_to_state[state_group] = dict(current_state_ids) - return state_group + return defer.succeed(state_group) def get_events(self, event_ids, **kwargs): - return { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } + return defer.succeed( + { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } + ) def get_state_group_delta(self, name): - return None, None + return defer.succeed((None, None)) def register_events(self, events): for e in events: @@ -120,7 +122,7 @@ class StateGroupStore(object): self._event_to_state_group[event_id] = state_group def get_room_version_id(self, room_id): - return RoomVersions.V1.identifier + return defer.succeed(RoomVersions.V1.identifier) class DictObj(dict): @@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase): context_store = {} # type: dict[str, EventContext] for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( {e.event_id for e in old_state}, set(current_state_ids.values()) @@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) prev_state_ids = yield context.get_prev_state_ids() @@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) + @defer.inlineCallbacks def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = self.store.store_state_group( + sg1 = yield self.store.store_state_group( prev_event_id_1, event.room_id, None, @@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = self.store.store_state_group( + sg2 = yield self.store.store_state_group( prev_event_id_2, event.room_id, None, @@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_2, sg2) - return self.state.compute_event_context(event) + result = yield defer.ensureDeferred(self.state.compute_event_context(event)) + return result diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 7b345b03bb..508aeba078 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,7 +17,7 @@ """ Utilities for running the unit tests """ -from typing import Awaitable, TypeVar +from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: # if next didn't raise, the awaitable hasn't completed. raise Exception("awaitable has not yet completed") + + +async def make_awaitable(result: Any): + """Create an awaitable that just returns a result.""" + return result -- cgit 1.5.1 From 8144bc26a7432463b7e70f9c03198d4724952522 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:21:34 -0400 Subject: Convert push to async/await. (#7948) --- changelog.d/7948.misc | 1 + synapse/push/action_generator.py | 7 +-- synapse/push/bulk_push_rule_evaluator.py | 62 ++++++++----------- synapse/push/httppusher.py | 58 ++++++++---------- synapse/push/presentable_names.py | 15 ++--- synapse/push/push_tools.py | 22 +++---- synapse/push/pusherpool.py | 70 +++++++++------------- .../storage/data_stores/main/event_push_actions.py | 4 +- tests/replication/slave/storage/test_events.py | 6 +- tests/storage/test_event_push_actions.py | 6 +- 10 files changed, 106 insertions(+), 145 deletions(-) create mode 100644 changelog.d/7948.misc (limited to 'synapse/push') diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc new file mode 100644 index 0000000000..7c2e2b18b7 --- /dev/null +++ b/changelog.d/7948.misc @@ -0,0 +1 @@ +Convert push to async/await. diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 1ffd5e2df3..0d23142653 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.metrics import Measure from .bulk_push_rule_evaluator import BulkPushRuleEvaluator @@ -37,7 +35,6 @@ class ActionGenerator(object): # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). - @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user(event, context) + await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 472ddf9f7d..04b9d8ac82 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,6 @@ from collections import namedtuple from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY @@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object): resizable=False, ) - @defer.inlineCallbacks - def _get_rules_for_event(self, event, context): + async def _get_rules_for_event(self, event, context): """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = yield self._get_rules_for_room(room_id) + rules_for_room = await self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(event, context) + rules_by_user = await rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = yield self.store.user_has_pusher(invited) + has_pusher = await self.store.user_has_pusher(invited) if has_pusher: rules_by_user = dict(rules_by_user) - rules_by_user[invited] = yield self.store.get_push_rules_for_user( + rules_by_user[invited] = await self.store.get_push_rules_for_user( invited ) @@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics, ) - @defer.inlineCallbacks - def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids() + async def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = yield self.store.get_event(pl_event_id) + pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = yield self.auth.compute_auth_events( + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level - @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): + async def action_for_event_by_user(self, event, context) -> None: """Given an event and context, evaluate the push rules and insert the results into the event_push_actions_staging table. - - Returns: - Deferred """ - rules_by_user = yield self._get_rules_for_event(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context(event, context) + room_members = await self.store.get_joined_users_from_context(event, context) ( power_levels, sender_power_level, - ) = yield self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object): continue if not event.is_state(): - is_ignored = yield self.store.is_ignored_by(event.sender, uid) + is_ignored = await self.store.is_ignored_by(event.sender, uid) if is_ignored: continue @@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -274,8 +266,7 @@ class RulesForRoom(object): # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - @defer.inlineCallbacks - def get_rules(self, event, context): + async def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -286,7 +277,7 @@ class RulesForRoom(object): self.room_push_rule_cache_metrics.inc_hits() return self.rules_by_user - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() @@ -304,9 +295,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield defer.ensureDeferred( - context.get_current_state_ids() - ) + current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -353,7 +342,7 @@ class RulesForRoom(object): # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) - yield self._update_rules_with_member_event_ids( + await self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) else: @@ -371,8 +360,7 @@ class RulesForRoom(object): ) return ret_rules_by_user - @defer.inlineCallbacks - def _update_rules_with_member_event_ids( + async def _update_rules_with_member_event_ids( self, ret_rules_by_user, member_event_ids, state_group, event ): """Update the partially filled rules_by_user dict by fetching rules for @@ -388,7 +376,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) + rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} @@ -410,7 +398,7 @@ class RulesForRoom(object): logger.debug("Joined: %r", interested_in_user_ids) - if_users_with_pushers = yield self.store.get_if_users_have_pushers( + if_users_with_pushers = await self.store.get_if_users_have_pushers( interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) @@ -420,7 +408,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) - users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + users_with_receipts = await self.store.get_users_with_read_receipts_in_room( self.room_id, on_invalidate=self.invalidate_all_cb ) @@ -431,7 +419,7 @@ class RulesForRoom(object): if uid in interested_in_user_ids: user_ids.add(uid) - rules_by_user = yield self.store.bulk_get_push_rules( + rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2fac07593b..4c469efb20 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -17,7 +17,6 @@ import logging from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.api.constants import EventTypes @@ -128,12 +127,11 @@ class HttpPusher(object): # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - @defer.inlineCallbacks - def _update_badge(self): + async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - yield self._send_badge(badge) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + await self._send_badge(badge) def on_timer(self): self._start_processing() @@ -152,8 +150,7 @@ class HttpPusher(object): run_as_background_process("httppush.process", self._process) - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -164,7 +161,7 @@ class HttpPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -172,8 +169,7 @@ class HttpPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to @@ -181,7 +177,7 @@ class HttpPusher(object): """ fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = yield fn( + unprocessed = await fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -203,13 +199,13 @@ class HttpPusher(object): "app_display_name": self.app_display_name, }, ): - processed = yield self._process_one(push_action) + processed = await self._process_one(push_action) if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, self.user_id, @@ -224,14 +220,14 @@ class HttpPusher(object): if self.failing_since: self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) @@ -250,7 +246,7 @@ class HttpPusher(object): ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, @@ -263,7 +259,7 @@ class HttpPusher(object): return self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: @@ -276,18 +272,17 @@ class HttpPusher(object): ) break - @defer.inlineCallbacks - def _process_one(self, push_action): + async def _process_one(self, push_action): if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: return True # It's been redacted - rejected = yield self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push(event, tweaks, badge) if rejected is False: return False @@ -301,11 +296,10 @@ class HttpPusher(object): ) else: logger.info("Pushkey %s was rejected: removing", pk) - yield self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict(self, event, tweaks, badge): priority = "low" if ( event.type == EventTypes.Encrypted @@ -335,7 +329,7 @@ class HttpPusher(object): } return d - ctx = yield push_tools.get_context_for_event( + ctx = await push_tools.get_context_for_event( self.storage, self.state_handler, event, self.user_id ) @@ -377,13 +371,12 @@ class HttpPusher(object): return d - @defer.inlineCallbacks - def dispatch_push(self, event, tweaks, badge): - notification_dict = yield self._build_notification_dict(event, tweaks, badge) + async def dispatch_push(self, event, tweaks, badge): + notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] try: - resp = yield self.http_client.post_json_get_json( + resp = await self.http_client.post_json_get_json( self.url, notification_dict ) except Exception as e: @@ -400,8 +393,7 @@ class HttpPusher(object): rejected = resp["rejected"] return rejected - @defer.inlineCallbacks - def _send_badge(self, badge): + async def _send_badge(self, badge): """ Args: badge (int): number of unread messages @@ -424,7 +416,7 @@ class HttpPusher(object): } } try: - yield self.http_client.post_json_get_json(self.url, d) + await self.http_client.post_json_get_json(self.url, d) http_badges_processed_counter.inc() except Exception as e: logger.warning( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0644a13cfc..d8f4a453cd 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -@defer.inlineCallbacks -def calculate_room_name( +async def calculate_room_name( store, room_state_ids, user_id, @@ -53,7 +50,7 @@ def calculate_room_name( """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: - m_room_name = yield store.get_event( + m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: @@ -61,7 +58,7 @@ def calculate_room_name( # does it have a canonical alias? if (EventTypes.CanonicalAlias, "") in room_state_ids: - canon_alias = yield store.get_event( + canon_alias = await store.get_event( room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( @@ -81,7 +78,7 @@ def calculate_room_name( my_member_event = None if (EventTypes.Member, user_id) in room_state_ids: - my_member_event = yield store.get_event( + my_member_event = await store.get_event( room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) @@ -90,7 +87,7 @@ def calculate_room_name( and my_member_event.content["membership"] == "invite" ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: - inviter_member_event = yield store.get_event( + inviter_member_event = await store.get_event( room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) @@ -107,7 +104,7 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: - member_events = yield store.get_events( + member_events = await store.get_events( list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 5dae4648c0..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage -@defer.inlineCallbacks -def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_local_user(user_id) - joins = yield store.get_rooms_for_user(user_id) +async def get_badge_count(store, user_id): + invites = await store.get_invited_rooms_for_local_user(user_id) + joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -32,7 +29,7 @@ def get_badge_count(store, user_id): if room_id in my_receipts_by_room: last_unread_event_id = my_receipts_by_room[room_id] - notifs = yield ( + notifs = await ( store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, last_unread_event_id ) @@ -43,23 +40,22 @@ def get_badge_count(store, user_id): return badge -@defer.inlineCallbacks -def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) + room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = yield calculate_room_name( + name = await calculate_room_name( storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield storage.main.get_event(sender_state_event_id) + sender_state_event = await storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2456f12f46..3c3262a88c 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -52,7 +50,7 @@ class PusherPool: Note that it is expected that each pusher will have its own 'processing' loop which will send out the notifications in the background, rather than blocking until the notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and - Pusher.on_new_receipts are not expected to return deferreds. + Pusher.on_new_receipts are not expected to return awaitables. """ def __init__(self, hs: "HomeServer"): @@ -77,8 +75,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -94,7 +91,7 @@ class PusherPool: """Creates a new pusher and adds it to the pool Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ time_now_msec = self.clock.time_msec() @@ -124,9 +121,9 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() + last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - yield self.store.add_pusher( + await self.store.add_pusher( user_id=user_id, access_token=access_token, kind=kind, @@ -140,15 +137,14 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) return pusher - @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user( + async def remove_pushers_by_app_id_and_pushkey_not_user( self, app_id, pushkey, not_user_id ): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: if p["user_name"] != not_user_id: logger.info( @@ -157,10 +153,9 @@ class PusherPool: pushkey, p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token(self, user_id, access_tokens): """Remove the pushers for a given user corresponding to a set of access_tokens. @@ -173,7 +168,7 @@ class PusherPool: return tokens = set(access_tokens) - for p in (yield self.store.get_pushers_by_user_id(user_id)): + for p in await self.store.get_pushers_by_user_id(user_id): if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", @@ -181,16 +176,15 @@ class PusherPool: p["pushkey"], p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. return try: - users_affected = yield self.store.get_push_action_users_in_range( + users_affected = await self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) @@ -202,8 +196,7 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): if not self.pushers: # nothing to do here. return @@ -211,7 +204,7 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - users_affected = yield self.store.get_users_sent_receipts_between( + users_affected = await self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) @@ -223,12 +216,11 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - @defer.inlineCallbacks - def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id(self, app_id, pushkey, user_id): """Look up the details for the given pusher, and start it Returns: - Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + EmailPusher|HttpPusher|None: The pusher started, if any """ if not self._should_start_pushers: return @@ -236,7 +228,7 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: @@ -245,34 +237,29 @@ class PusherPool: pusher = None if pusher_dict: - pusher = yield self._start_pusher(pusher_dict) + pusher = await self._start_pusher(pusher_dict) return pusher - @defer.inlineCallbacks - def _start_pushers(self): + async def _start_pushers(self) -> None: """Start all the pushers - - Returns: - Deferred """ - pushers = yield self.store.get_all_pushers() + pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. - yield concurrently_execute(self._start_pusher, pushers, 10) + await concurrently_execute(self._start_pusher, pushers, 10) logger.info("Started pushers") - @defer.inlineCallbacks - def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusherdict): """Start the given pusher Args: pusherdict (dict): dict with the values pulled from the db table Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ if not self._pusher_shard_config.should_handle( self._instance_name, pusherdict["user_name"] @@ -315,7 +302,7 @@ class PusherPool: user_id = pusherdict["user_name"] last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: - have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering ) else: @@ -327,8 +314,7 @@ class PusherPool: return p - @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) @@ -340,6 +326,6 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 504babaa7e..18297cf3b8 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): _get_if_maybe_push_in_range_for_user_txn, ) - def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging(self, event_id, user_id_actions): """Add the push actions for the event to the push action staging area. Args: @@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.db.runInteraction( + return await self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1a88c7fb80..0b5204654c 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = self.get_success(state_handler.compute_event_context(event)) - self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, {user_id: actions for user_id, actions in push_actions} + ) ) return event, context diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b45bc9c115..43dbeb42c5 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -72,8 +72,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action} + yield defer.ensureDeferred( + self.store.add_push_actions_to_staging( + event.event_id, {user_id: action} + ) ) yield self.store.db.runInteraction( "", -- cgit 1.5.1 From 9725c59247131d243316ff299e6864098d9bdc58 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 28 Jul 2020 19:20:55 +0100 Subject: Implement new experimental push rules with a database hack to enable them --- synapse/push/baserules.py | 217 ++++++++++++++++++++- synapse/storage/data_stores/main/push_rule.py | 35 +++- .../main/schema/delta/58/13new_push_rules_tmp.sql | 21 ++ 3 files changed, 259 insertions(+), 14 deletions(-) create mode 100644 synapse/storage/data_stores/main/schema/delta/58/13new_push_rules_tmp.sql (limited to 'synapse/push') diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 286374d0b5..e06b1a01e6 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -19,7 +19,7 @@ import copy from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP -def list_with_base_rules(rawrules): +def list_with_base_rules(rawrules, use_new_defaults=False): """Combine the list of rules set by the user with the default push rules Args: @@ -43,7 +43,9 @@ def list_with_base_rules(rawrules): ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) @@ -54,6 +56,7 @@ def list_with_base_rules(rawrules): make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, + use_new_defaults, ) ) current_prio_class -= 1 @@ -62,6 +65,7 @@ def list_with_base_rules(rawrules): make_base_prepend_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, + use_new_defaults, ) ) @@ -70,27 +74,31 @@ def list_with_base_rules(rawrules): while current_prio_class > 0: ruleslist.extend( make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, + use_new_defaults, ) ) return ruleslist -def make_base_append_rules(kind, modified_base_rules): +def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = BASE_APPEND_OVERRIDE_RULES + rules = NEW_APPEND_OVERRIDE_RULES if use_new_defaults else BASE_APPEND_OVERRIDE_RULES elif kind == "underride": - rules = BASE_APPEND_UNDERRIDE_RULES + rules = NEW_APPEND_UNDERRIDE_RULES if use_new_defaults else BASE_APPEND_UNDERRIDE_RULES elif kind == "content": rules = BASE_APPEND_CONTENT_RULES @@ -105,11 +113,11 @@ def make_base_append_rules(kind, modified_base_rules): return rules -def make_base_prepend_rules(kind, modified_base_rules): +def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = BASE_PREPEND_OVERRIDE_RULES + rules = NEW_PREPEND_OVERRIDE_RULES if use_new_defaults else BASE_PREPEND_OVERRIDE_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) @@ -151,6 +159,16 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] +NEW_PREPEND_OVERRIDE_RULES = [ + { + "rule_id": "global/override/.m.rule.master", + "enabled": False, + "conditions": [], + "actions": [], + } +] + + BASE_APPEND_OVERRIDE_RULES = [ { "rule_id": "global/override/.m.rule.suppress_notices", @@ -270,6 +288,141 @@ BASE_APPEND_OVERRIDE_RULES = [ ] +NEW_APPEND_OVERRIDE_RULES = [ + { + "rule_id": "global/override/.m.rule.encrypted", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.encrypted", + "_id": "_encrypted", + } + ], + "actions": ["notify"], + }, + { + "rule_id": "global/override/.m.rule.suppress_notices", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.message", + "_id": "_suppress_notices_type", + }, + { + "kind": "event_match", + "key": "content.msgtype", + "pattern": "m.notice", + "_id": "_suppress_notices", + } + ], + "actions": [], + }, + { + "rule_id": "global/underride/.m.rule.suppress_edits", + "conditions": [ + { + "kind": "event_match", + "key": "m.relates_to.m.rel_type", + "pattern": "m.replace", + "_id": "_suppress_edits", + } + ], + "actions": [], + }, + { + "rule_id": "global/override/.m.rule.invite_for_me", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.member", + "_id": "_member", + }, + { + "kind": "event_match", + "key": "content.membership", + "pattern": "invite", + "_id": "_invite_member", + }, + {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + ], + }, + { + "rule_id": "global/override/.m.rule.contains_display_name", + "conditions": [{"kind": "contains_display_name"}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + }, + { + "rule_id": "global/override/.m.rule.tombstone", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.room.tombstone", + "_id": "_tombstone", + }, + { + "kind": "event_match", + "key": "state_key", + "pattern": "", + "_id": "_tombstone_statekey", + }, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight"}, + ], + }, + { + "rule_id": "global/override/.m.rule.roomnotif", + "conditions": [ + { + "kind": "event_match", + "key": "content.body", + "pattern": "@room", + "_id": "_roomnotif_content", + }, + { + "kind": "sender_notification_permission", + "key": "room", + "_id": "_roomnotif_pl", + }, + ], + "actions": [ + "notify", + {"set_tweak": "highlight"}, + {"set_tweak": "sound", "value": "default"}, + ], + }, + { + "rule_id": "global/override/.m.rule.call", + "conditions": [ + { + "kind": "event_match", + "key": "type", + "pattern": "m.call.invite", + "_id": "_call", + } + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "ring"}, + ], + } +] + + BASE_APPEND_UNDERRIDE_RULES = [ { "rule_id": "global/underride/.m.rule.call", @@ -354,6 +507,29 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +NEW_APPEND_UNDERRIDE_RULES = [ + { + "rule_id": "global/underride/.m.rule.room_one_to_one", + "conditions": [ + {"kind": "room_member_count", "is": "2", "_id": "member_count"}, + {"kind": "event_match", "key": "content.body", "pattern": "*", "_id": "body"}, + ], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + ], + }, + { + "rule_id": "global/underride/.m.rule.message", + "conditions": [ + {"kind": "event_match", "key": "content.body", "pattern": "*", "_id": "body"}, + ], + "actions": ["notify"], + "enabled": False, + }, +] + + BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: @@ -375,3 +551,26 @@ for r in BASE_APPEND_UNDERRIDE_RULES: r["priority_class"] = PRIORITY_CLASS_MAP["underride"] r["default"] = True BASE_RULE_IDS.add(r["rule_id"]) + + +NEW_RULE_IDS = set() + +for r in BASE_APPEND_CONTENT_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["content"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_PREPEND_OVERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_APPEND_OVERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["override"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) + +for r in NEW_APPEND_UNDERRIDE_RULES: + r["priority_class"] = PRIORITY_CLASS_MAP["underride"] + r["default"] = True + NEW_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index d181488db7..c10da245d2 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -39,7 +39,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -def _load_rules(rawrules, enabled_map): +def _load_rules(rawrules, enabled_map, use_new_defaults=False): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -49,7 +49,7 @@ def _load_rules(rawrules, enabled_map): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist)) + rules = list(list_with_base_rules(ruleslist, use_new_defaults)) for i, rule in enumerate(rules): rule_id = rule["rule_id"] @@ -115,7 +115,7 @@ class PushRulesWorkerStore( raise NotImplementedError() @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): + def _get_push_rules_for_user(self, user_id, use_new_defaults=False): rows = yield self.db.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, @@ -134,8 +134,22 @@ class PushRulesWorkerStore( enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - rules = _load_rules(rows, enabled_map) + rules = _load_rules(rows, enabled_map, use_new_defaults) + + return rules + + @defer.inlineCallbacks + def get_push_rules_for_user(self, user_id): + # Temporary hack so we can use the new experimental default push rules to some + # users without impacting others. + use_new_defaults = yield self.db.simple_select_list( + table="new_push_rules_users_tmp", + keyvalues={"user_id": user_id}, + retcols=("user_id",), + desc="get_user_new_default_push_rules", + ) + rules = yield self._get_push_rules_for_user(user_id, bool(use_new_defaults)) return rules @cachedInlineCallbacks(max_entries=5000) @@ -194,7 +208,18 @@ class PushRulesWorkerStore( enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) + # Temporary hack so we can use the new experimental default push rules to some + # users without impacting others. + use_new_defaults = yield self.db.simple_select_list( + table="new_push_rules_users_tmp", + keyvalues={"user_id": user_id}, + retcols=("user_id",), + desc="get_user_new_default_push_rules", + ) + + results[user_id] = _load_rules( + rules, enabled_map_by_user.get(user_id, {}), bool(use_new_defaults), + ) return results diff --git a/synapse/storage/data_stores/main/schema/delta/58/13new_push_rules_tmp.sql b/synapse/storage/data_stores/main/schema/delta/58/13new_push_rules_tmp.sql new file mode 100644 index 0000000000..b7daf1c67b --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/13new_push_rules_tmp.sql @@ -0,0 +1,21 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This is a temporary table in which we store the IDs of the users for which we need to +-- serve the new experimental default push rules. The purpose of this table is to help +-- test these new defaults, so it shall be dropped when the experimentation is done. +CREATE TABLE IF NOT EXISTS new_push_rules_users_tmp ( + user_id TEXT PRIMARY KEY +); \ No newline at end of file -- cgit 1.5.1 From 8dff4a12424cda9e4abaa5f2905d58aa6e723777 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 29 Jul 2020 18:26:55 +0100 Subject: Re-implement unread counts (#7736) --- changelog.d/7736.feature | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 + synapse/push/push_tools.py | 17 +-- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/data_stores/main/cache.py | 1 + synapse/storage/data_stores/main/events.py | 48 ++++++- synapse/storage/data_stores/main/events_worker.py | 86 ++++++++++- .../main/schema/delta/58/12unread_messages.sql | 18 +++ tests/rest/client/v1/utils.py | 20 +++ tests/rest/client/v2_alpha/test_sync.py | 157 ++++++++++++++++++++- 11 files changed, 339 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7736.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql (limited to 'synapse/push') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature new file mode 100644 index 0000000000..c97864677a --- /dev/null +++ b/changelog.d/7736.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22a6abd7d2..bee525197f 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], + "events": ["processed", "outlier", "contains_url", "count_as_unread"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ebd3e98105..eaa4eeadf7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,6 +103,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1886,6 +1887,10 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] + + unread_count = await self.store.get_unread_message_count_for_user( + room_id, sync_config.user.to_string(), + ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1894,6 +1899,7 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..bc8f71916b 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,22 +21,13 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") - badge = len(invites) for room_id in joins: - if room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[room_id] - - notifs = await ( - store.get_unread_event_push_actions_by_room_for_user( - room_id, user_id, last_unread_event_id - ) - ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + unread_count = await store.get_unread_message_count_for_user(room_id, user_id) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if unread_count else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..3f5bf75e59 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index f39f556c20..edc3624fed 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) + self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6f2e0d15cc..0c9c02afa1 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -53,6 +53,47 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + def encode_json(json_object): """ @@ -196,6 +237,10 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() + self.store.get_unread_message_count_for_user.invalidate_many( + (event.room_id,), + ) + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -817,8 +862,9 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), + "count_as_unread": should_count_as_unread(event, context), } - for event, _ in events_and_contexts + for event, context in events_and_contexts ], ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e812c67078..b03b259636 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. + + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. + txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql new file mode 100644 index 0000000000..531b532c73 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Store a boolean value in the events table for whether the event should be counted in +-- the unread_count property of sync responses. +ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 22d734e763..7f8252330a 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -143,6 +143,26 @@ class RestHelper(object): return channel.json_body + def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body + def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: True): + """Syncs and compares the unread count with the expected value.""" + + request, channel = self.make_request( + "GET", self.url % self.next_batch, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1 From 60328ce9fbe90299253ba740f2648c42b9091920 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 30 Jul 2020 19:02:28 +0100 Subject: Lint --- synapse/push/baserules.py | 51 +++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 19 deletions(-) (limited to 'synapse/push') diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index e06b1a01e6..172fd00f19 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -96,9 +96,17 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = NEW_APPEND_OVERRIDE_RULES if use_new_defaults else BASE_APPEND_OVERRIDE_RULES + rules = ( + NEW_APPEND_OVERRIDE_RULES + if use_new_defaults + else BASE_APPEND_OVERRIDE_RULES + ) elif kind == "underride": - rules = NEW_APPEND_UNDERRIDE_RULES if use_new_defaults else BASE_APPEND_UNDERRIDE_RULES + rules = ( + NEW_APPEND_UNDERRIDE_RULES + if use_new_defaults + else BASE_APPEND_UNDERRIDE_RULES + ) elif kind == "content": rules = BASE_APPEND_CONTENT_RULES @@ -117,7 +125,11 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = NEW_PREPEND_OVERRIDE_RULES if use_new_defaults else BASE_PREPEND_OVERRIDE_RULES + rules = ( + NEW_PREPEND_OVERRIDE_RULES + if use_new_defaults + else BASE_PREPEND_OVERRIDE_RULES + ) # Copy the rules before modifying them rules = copy.deepcopy(rules) @@ -315,7 +327,7 @@ NEW_APPEND_OVERRIDE_RULES = [ "key": "content.msgtype", "pattern": "m.notice", "_id": "_suppress_notices", - } + }, ], "actions": [], }, @@ -348,10 +360,7 @@ NEW_APPEND_OVERRIDE_RULES = [ }, {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, ], - "actions": [ - "notify", - {"set_tweak": "sound", "value": "default"}, - ], + "actions": ["notify", {"set_tweak": "sound", "value": "default"}], }, { "rule_id": "global/override/.m.rule.contains_display_name", @@ -415,11 +424,8 @@ NEW_APPEND_OVERRIDE_RULES = [ "_id": "_call", } ], - "actions": [ - "notify", - {"set_tweak": "sound", "value": "ring"}, - ], - } + "actions": ["notify", {"set_tweak": "sound", "value": "ring"}], + }, ] @@ -512,17 +518,24 @@ NEW_APPEND_UNDERRIDE_RULES = [ "rule_id": "global/underride/.m.rule.room_one_to_one", "conditions": [ {"kind": "room_member_count", "is": "2", "_id": "member_count"}, - {"kind": "event_match", "key": "content.body", "pattern": "*", "_id": "body"}, - ], - "actions": [ - "notify", - {"set_tweak": "sound", "value": "default"}, + { + "kind": "event_match", + "key": "content.body", + "pattern": "*", + "_id": "body", + }, ], + "actions": ["notify", {"set_tweak": "sound", "value": "default"}], }, { "rule_id": "global/underride/.m.rule.message", "conditions": [ - {"kind": "event_match", "key": "content.body", "pattern": "*", "_id": "body"}, + { + "kind": "event_match", + "key": "content.body", + "pattern": "*", + "_id": "body", + }, ], "actions": ["notify"], "enabled": False, -- cgit 1.5.1 From dd11f575a29b59aced6cfa7ea7b9faea6f968f8d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 6 Aug 2020 10:52:26 +0100 Subject: Incorporate review --- synapse/config/server.py | 3 +++ synapse/push/baserules.py | 20 ++++---------------- synapse/rest/client/v1/push_rule.py | 4 ++-- synapse/storage/data_stores/main/push_rule.py | 6 +++--- 4 files changed, 12 insertions(+), 21 deletions(-) (limited to 'synapse/push') diff --git a/synapse/config/server.py b/synapse/config/server.py index 68d143410f..00fa07c225 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -540,6 +540,9 @@ class ServerConfig(Config): if not isinstance(self.users_new_default_push_rules, list): raise ConfigError("'users_new_default_push_rules' must be a list") + # Turn the list into a set to improve lookup speed. + self.users_new_default_push_rules = set(self.users_new_default_push_rules) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 172fd00f19..8047873ff1 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -24,6 +24,8 @@ def list_with_base_rules(rawrules, use_new_defaults=False): Args: rawrules(list): The rules the user has modified or set. + use_new_defaults(bool): Whether to use the new experimental default rules when + appending or prepending default rules. Returns: A new list with the rules set by the user combined with the defaults. @@ -125,11 +127,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): rules = [] if kind == "override": - rules = ( - NEW_PREPEND_OVERRIDE_RULES - if use_new_defaults - else BASE_PREPEND_OVERRIDE_RULES - ) + rules = BASE_PREPEND_OVERRIDE_RULES # Copy the rules before modifying them rules = copy.deepcopy(rules) @@ -171,16 +169,6 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] -NEW_PREPEND_OVERRIDE_RULES = [ - { - "rule_id": "global/override/.m.rule.master", - "enabled": False, - "conditions": [], - "actions": [], - } -] - - BASE_APPEND_OVERRIDE_RULES = [ { "rule_id": "global/override/.m.rule.suppress_notices", @@ -573,7 +561,7 @@ for r in BASE_APPEND_CONTENT_RULES: r["default"] = True NEW_RULE_IDS.add(r["rule_id"]) -for r in NEW_PREPEND_OVERRIDE_RULES: +for r in BASE_PREPEND_OVERRIDE_RULES: r["priority_class"] = PRIORITY_CLASS_MAP["override"] r["default"] = True NEW_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index f66b8fa7c4..00831879f3 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -45,7 +45,7 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None - self.users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = hs.config.users_new_default_push_rules async def on_PUT(self, request, path): if self._is_worker: @@ -181,7 +181,7 @@ class PushRuleRestServlet(RestServlet): rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: - if user_id in self.users_new_default_push_rules: + if user_id in self._users_new_default_push_rules: rule_ids = NEW_RULE_IDS else: rule_ids = BASE_RULE_IDS diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index d644a0b8ce..6b650f8ba8 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -105,7 +105,7 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self.users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = hs.config.users_new_default_push_rules @abc.abstractmethod def get_max_push_rules_stream_id(self): @@ -136,7 +136,7 @@ class PushRulesWorkerStore( enabled_map = yield self.get_push_rules_enabled_for_user(user_id) - use_new_defaults = user_id in self.users_new_default_push_rules + use_new_defaults = user_id in self._users_new_default_push_rules rules = _load_rules(rows, enabled_map, use_new_defaults) @@ -198,7 +198,7 @@ class PushRulesWorkerStore( enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - use_new_defaults = user_id in self.users_new_default_push_rules + use_new_defaults = user_id in self._users_new_default_push_rules results[user_id] = _load_rules( rules, enabled_map_by_user.get(user_id, {}), use_new_defaults, -- cgit 1.5.1 From d4a7829b12197faf52eb487c443ee09acafeb37e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Aug 2020 08:30:06 -0400 Subject: Convert synapse.api to async/await (#8031) --- changelog.d/8031.misc | 1 + synapse/api/auth.py | 123 ++++++++++----------- synapse/api/auth_blocking.py | 13 +-- synapse/api/filtering.py | 7 +- synapse/events/builder.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/module_api/__init__.py | 8 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/rest/client/v1/directory.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/storage/databases/main/client_ips.py | 5 +- tests/api/test_auth.py | 69 +++++++----- tests/api/test_filtering.py | 36 ++++-- tests/handlers/test_typing.py | 4 +- tests/rest/admin/test_user.py | 10 +- tests/rest/client/v1/test_profile.py | 4 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/rest/client/v1/test_typing.py | 6 +- .../test_resource_limits_server_notices.py | 2 +- tests/unittest.py | 24 ++-- 22 files changed, 172 insertions(+), 160 deletions(-) create mode 100644 changelog.d/8031.misc (limited to 'synapse/push') diff --git a/changelog.d/8031.misc b/changelog.d/8031.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8031.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2178e623da..d8190f92ab 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import List, Optional, Tuple import pymacaroons from netaddr import IPAddress -from twisted.internet import defer from twisted.web.server import Request import synapse.types @@ -80,13 +79,14 @@ class Auth(object): self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key - @defer.inlineCallbacks - def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) - auth_events_ids = yield self.compute_auth_events( + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ): + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -94,14 +94,13 @@ class Auth(object): room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) - @defer.inlineCallbacks - def check_user_in_room( + async def check_user_in_room( self, room_id: str, user_id: str, current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ): + ) -> EventBase: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -119,37 +118,35 @@ class Auth(object): Raises: AuthError if the user is/was not in the room. Returns: - Deferred[Optional[EventBase]]: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield defer.ensureDeferred( - self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) + member = await self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - membership = member.membership if member else None - if membership == Membership.JOIN: - return member + if member: + membership = member.membership - # XXX this looks totally bogus. Why do we not allow users who have been banned, - # or those who were members previously and have been re-invited? - if allow_departed_users and membership == Membership.LEAVE: - forgot = yield self.store.did_forget(user_id, room_id) - if not forgot: + if membership == Membership.JOIN: return member + # XXX this looks totally bogus. Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: + forgot = await self.store.did_forget(user_id, room_id) + if not forgot: + return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - @defer.inlineCallbacks - def check_host_in_room(self, room_id, host): + async def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.is_host_joined(room_id, host) + latest_event_ids = await self.store.is_host_joined(room_id, host) return latest_event_ids def can_federate(self, event, auth_events): @@ -160,14 +157,13 @@ class Auth(object): def get_public_keys(self, invite_event): return event_auth.get_public_keys(invite_event) - @defer.inlineCallbacks - def get_user_by_req( + async def get_user_by_req( self, request: Request, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, - ): + ) -> synapse.types.Requester: """ Get a registered user's ID. Args: @@ -180,7 +176,7 @@ class Auth(object): /login will deliver access tokens regardless of expiration. Returns: - defer.Deferred: resolves to a `synapse.types.Requester` object + Resolves to the requester Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -194,14 +190,14 @@ class Auth(object): access_token = self.get_access_token_from_request(request) - user_id, app_service = yield self._get_appservice_user_id(request) + user_id, app_service = await self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) if ip_addr and self._track_appservice_user_ips: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user_id, access_token=access_token, ip=ip_addr, @@ -211,7 +207,7 @@ class Auth(object): return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token( + user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) user = user_info["user"] @@ -221,7 +217,7 @@ class Auth(object): # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) if ( expiration_ts is not None and self.clock.time_msec() >= expiration_ts @@ -235,7 +231,7 @@ class Auth(object): device_id = user_info.get("device_id") if user and access_token and ip_addr: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, @@ -261,8 +257,7 @@ class Auth(object): except KeyError: raise MissingClientTokenError() - @defer.inlineCallbacks - def _get_appservice_user_id(self, request): + async def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) @@ -283,14 +278,13 @@ class Auth(object): if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") - if not (yield self.store.get_user_by_id(user_id)): + if not (await self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") return user_id, app_service - @defer.inlineCallbacks - def get_user_by_access_token( + async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ): + ) -> dict: """ Validate access token and get user_id from it Args: @@ -300,7 +294,7 @@ class Auth(object): allow_expired: If False, raises an InvalidClientTokenError if the token is expired Returns: - Deferred[dict]: dict that includes: + dict that includes: `user` (UserID) `is_guest` (bool) `token_id` (int|None): access token id. May be None if guest @@ -314,7 +308,7 @@ class Auth(object): if rights == "access": # first look in the database - r = yield self._look_up_user_by_access_token(token) + r = await self._look_up_user_by_access_token(token) if r: valid_until_ms = r["valid_until_ms"] if ( @@ -352,7 +346,7 @@ class Auth(object): # It would of course be much easier to store guest access # tokens in the database as well, but that would break existing # guest tokens. - stored_user = yield self.store.get_user_by_id(user_id) + stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: @@ -482,9 +476,8 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - @defer.inlineCallbacks - def _look_up_user_by_access_token(self, token): - ret = yield self.store.get_user_by_access_token(token) + async def _look_up_user_by_access_token(self, token): + ret = await self.store.get_user_by_access_token(token) if not ret: return None @@ -507,7 +500,7 @@ class Auth(object): logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender - return defer.succeed(service) + return service async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. @@ -522,7 +515,7 @@ class Auth(object): def compute_auth_events( self, event, current_state_ids: StateMap[str], for_verification: bool = False, - ): + ) -> List[str]: """Given an event and current state return the list of event IDs used to auth an event. @@ -530,11 +523,11 @@ class Auth(object): should be added to the event's `auth_events`. Returns: - defer.Deferred(list[str]): List of event IDs. + List of event IDs. """ if event.type == EventTypes.Create: - return defer.succeed([]) + return [] # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding @@ -553,7 +546,7 @@ class Auth(object): if auth_ev_id: auth_ids.append(auth_ev_id) - return defer.succeed(auth_ids) + return auth_ids async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the @@ -636,10 +629,9 @@ class Auth(object): return query_params[0].decode("ascii") - @defer.inlineCallbacks - def check_user_in_room_or_world_readable( + async def check_user_in_room_or_world_readable( self, room_id: str, user_id: str, allow_departed_users: bool = False - ): + ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. @@ -650,10 +642,9 @@ class Auth(object): members but have now departed Returns: - Deferred[tuple[str, str|None]]: Resolves to the current membership of - the user in the room and the membership event ID of the user. If - the user is not in the room and never has been, then - `(Membership.JOIN, None)` is returned. + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. """ try: @@ -662,15 +653,13 @@ class Auth(object): # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_in_room( + member_event = await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield defer.ensureDeferred( - self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) + visibility = await self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" ) if ( visibility diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 5c499b6b4e..49093bf181 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved @@ -36,8 +34,7 @@ class AuthBlocking(object): self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag @@ -60,7 +57,7 @@ class AuthBlocking(object): if user_id is not None: if user_id == self._server_notices_mxid: return - if (yield self.store.is_support_user(user_id)): + if await self.store.is_support_user(user_id): return if self._hs_disabled: @@ -76,11 +73,11 @@ class AuthBlocking(object): # If the user is already part of the MAU cohort or a trial user if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) + timestamp = await self.store.user_last_seen_monthly_active(user_id) if timestamp: return - is_trial = yield self.store.is_trial_user(user_id) + is_trial = await self.store.is_trial_user(user_id) if is_trial: return elif threepid: @@ -93,7 +90,7 @@ class AuthBlocking(object): # allow registration. Support users are excluded from MAU checks. return # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() + current_mau = await self.store.get_monthly_active_count() if current_mau >= self._max_mau_value: raise ResourceLimitError( 403, diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f988f62a1e..7393d6cb74 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -21,8 +21,6 @@ import jsonschema from canonicaljson import json from jsonschema import FormatChecker -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState @@ -137,9 +135,8 @@ class Filtering(object): super(Filtering, self).__init__() self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_user_filter(self, user_localpart, filter_id): - result = yield self.store.get_user_filter(user_localpart, filter_id) + async def get_user_filter(self, user_localpart, filter_id): + result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) def add_user_filter(self, user_localpart, user_filter): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 69b53ca2bc..4e179d49b3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,7 +106,7 @@ class EventBuilder(object): state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_ids = await self._auth.compute_auth_events(self, state_ids) + auth_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b3764dedae..593932adb7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler): if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 43901d0934..708533d4d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1061,7 +1061,7 @@ class EventCreationHandler(object): raise SynapseError(400, "Cannot redact event from a different room") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8201849951..c2fb757d9a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,12 +194,16 @@ class ModuleApi(object): synapse.api.errors.AuthError: the access token is invalid """ # see if the access token corresponds to a device - user_info = yield self._auth.get_user_by_access_token(access_token) + user_info = yield defer.ensureDeferred( + self._auth.get_user_by_access_token(access_token) + ) device_id = user_info.get("device_id") user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self._hs.get_device_handler().delete_device(user_id, device_id) + yield defer.ensureDeferred( + self._hs.get_device_handler().delete_device(user_id, device_id) + ) else: # no associated device. Just delete the access token. yield defer.ensureDeferred( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 04b9d8ac82..e7fcee0e87 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object): pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 60dd3f6701..a6fdedde63 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore): name="client_ip_last_seen", keylen=4, max_entries=50000 ) - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 5934b1fe8b..b210015173 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler try: - service = await self.auth.get_appservice_by_req(request) + service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) await dir_handler.delete_appservice_association(service, room_alias) logger.info( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a4c079196d..c549c090b3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = await self.auth.get_appservice_by_req(request) + appservice = self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 712c8d0264..50d71f5ebc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks - def insert_client_ip( + async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): if not now: @@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield self.populate_monthly_active_users(user_id) + await self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 0bfb86bf1f..5d45689c8c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) + self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.is_support_user = Mock(return_value=defer.succeed(False)) @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" @@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") @@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + # This just needs to return a truth-y value. + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.failureResultOf(d, AuthError) @defer.inlineCallbacks def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org", "device_id": "device"} + return_value=defer.succeed( + {"name": "@baldrick:matrix.org", "device_id": "device"} + ) ) user_id = "@baldrick:matrix.org" @@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase): def get_user(tok): if token != tok: - return None - return { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + return defer.succeed(None) + return defer.succeed( + { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={"is_guest": False}) + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 4e67503cf0..1fab1d6b69 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart + "2", filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events=events) @@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events) @@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals( user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 + yield defer.ensureDeferred( + self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) ) ), ) @@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase): user_localpart=user_localpart, user_filter=user_filter_json ) - filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5878f74175..b7d0adb10e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - return defer.succeed(None) + return None hs.get_auth().check_user_in_room = check_user_in_room diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f16eef15f7..17d0aae2e9 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -20,6 +20,8 @@ import urllib.parse from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() # Set monthly active users to the limit - store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit self.get_failure( @@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 8df58b4a63..ace0a3c08d 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase): profile_handler=self.mock_handler, ) - def _get_user_by_req(request=None, allow_guest=False): - return defer.succeed(synapse.types.create_requester(myid)) + async def _get_user_by_req(request=None, allow_guest=False): + return synapse.types.create_requester(myid) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5ccda8b2bd..ef6b775ed2 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -23,8 +23,6 @@ from urllib import parse as urlparse from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus @@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None self.hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 18260bb90e..94d2bf2eb1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_handlers().federation_handler = Mock() - def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, @@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f70353b0d..3f88abe3d2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=1000) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) diff --git a/tests/unittest.py b/tests/unittest.py index 2152c693f2..d0bba3ddef 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: - def get_user_by_access_token(token=None, allow_guest=False): - return succeed( - { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - ) - - def get_user_by_req(request, allow_guest=False, rights="access"): - return succeed( - create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + async def get_user_by_req(request, allow_guest=False, rights="access"): + return create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None ) self.hs.get_auth().get_user_by_req = get_user_by_req -- cgit 1.5.1 From 2ffd6783c7af12e3c29e1a44dee4a9deeb83890b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 6 Aug 2020 17:15:35 +0100 Subject: Revert #7736 (#8039) --- changelog.d/7736.feature | 1 - changelog.d/8039.misc | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 - synapse/push/push_tools.py | 17 ++- synapse/rest/client/v2_alpha/sync.py | 1 - synapse/storage/databases/main/cache.py | 1 - synapse/storage/databases/main/events.py | 48 +------ synapse/storage/databases/main/events_worker.py | 86 +---------- .../main/schema/delta/58/12unread_messages.sql | 18 --- tests/rest/client/v1/utils.py | 20 --- tests/rest/client/v2_alpha/test_sync.py | 157 +-------------------- 12 files changed, 19 insertions(+), 339 deletions(-) delete mode 100644 changelog.d/7736.feature create mode 100644 changelog.d/8039.misc delete mode 100644 synapse/storage/databases/main/schema/delta/58/12unread_messages.sql (limited to 'synapse/push') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature deleted file mode 100644 index feb02be234..0000000000 --- a/changelog.d/7736.feature +++ /dev/null @@ -1 +0,0 @@ -Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). diff --git a/changelog.d/8039.misc b/changelog.d/8039.misc new file mode 100644 index 0000000000..599933c80e --- /dev/null +++ b/changelog.d/8039.misc @@ -0,0 +1 @@ +Revert MSC2654 implementation because of perf issues. Please delete this line when processing the 1.19 changelog. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index ae5e1810fc..a34bdf1830 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -67,7 +67,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url", "count_as_unread"], + "events": ["processed", "outlier", "contains_url"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5a19bac929..c42dac18f5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,7 +103,6 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1887,10 +1886,6 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] - - unread_count = await self.store.get_unread_message_count_for_user( - room_id, sync_config.user.to_string(), - ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1899,7 +1894,6 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, - unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index bc8f71916b..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,13 +21,22 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + badge = len(invites) for room_id in joins: - unread_count = await store.get_unread_message_count_for_user(room_id, user_id) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if unread_count else 0 + if room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[room_id] + + notifs = await ( + store.get_unread_event_push_actions_by_room_for_user( + room_id, user_id, last_unread_event_id + ) + ) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if notifs["notify_count"] else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 3f5bf75e59..a5c24fbd63 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,7 +426,6 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 683afde52b..10de446065 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 4d8a24ce4b..1a68bf32cb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -53,47 +53,6 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) -STATE_EVENT_TYPES_TO_MARK_UNREAD = { - EventTypes.Topic, - EventTypes.Name, - EventTypes.RoomAvatar, - EventTypes.Tombstone, -} - - -def should_count_as_unread(event: EventBase, context: EventContext) -> bool: - # Exclude rejected and soft-failed events. - if context.rejected or event.internal_metadata.is_soft_failed(): - return False - - # Exclude notices. - if ( - not event.is_state() - and event.type == EventTypes.Message - and event.content.get("msgtype") == "m.notice" - ): - return False - - # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: - return False - - # Mark events that have a non-empty string body as unread. - body = event.content.get("body") - if isinstance(body, str) and body: - return True - - # Mark some state events as unread. - if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: - return True - - # Mark encrypted events as unread. - if not event.is_state() and event.type == EventTypes.Encrypted: - return True - - return False - def encode_json(json_object): """ @@ -239,10 +198,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - self.store.get_unread_message_count_for_user.invalidate_many( - (event.room_id,), - ) - for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -864,9 +819,8 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), - "count_as_unread": should_count_as_unread(event, context), } - for event, context in events_and_contexts + for event, _ in events_and_contexts ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a7b7393f6e..755b7a2a85 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -41,15 +41,9 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import ( - Cache, - _CacheContext, - cached, - cachedInlineCallbacks, -) +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1364,84 +1358,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - @cached(tree=True, cache_context=True) - async def get_unread_message_count_for_user( - self, room_id: str, user_id: str, cache_context: _CacheContext, - ) -> int: - """Retrieve the count of unread messages for the given room and user. - - Args: - room_id: The ID of the room to count unread messages in. - user_id: The ID of the user to count unread messages for. - - Returns: - The number of unread messages for the given user in the given room. - """ - with Measure(self._clock, "get_unread_message_count_for_user"): - last_read_event_id = await self.get_last_receipt_event_id_for_user( - user_id=user_id, - room_id=room_id, - receipt_type="m.read", - on_invalidate=cache_context.invalidate, - ) - - return await self.db_pool.runInteraction( - "get_unread_message_count_for_user", - self._get_unread_message_count_for_user_txn, - user_id, - room_id, - last_read_event_id, - ) - - def _get_unread_message_count_for_user_txn( - self, - txn: Cursor, - user_id: str, - room_id: str, - last_read_event_id: Optional[str], - ) -> int: - if last_read_event_id: - # Get the stream ordering for the last read event. - stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn=txn, - table="events", - keyvalues={"room_id": room_id, "event_id": last_read_event_id}, - retcol="stream_ordering", - ) - else: - # If there's no read receipt for that room, it probably means the user hasn't - # opened it yet, in which case use the stream ID of their join event. - # We can't just set it to 0 otherwise messages from other local users from - # before this user joined will be counted as well. - txn.execute( - """ - SELECT stream_ordering FROM local_current_membership - LEFT JOIN events USING (event_id, room_id) - WHERE membership = 'join' - AND user_id = ? - AND room_id = ? - """, - (user_id, room_id), - ) - row = txn.fetchone() - - if row is None: - return 0 - - stream_ordering = row[0] - - # Count the messages that qualify as unread after the stream ordering we've just - # retrieved. - sql = """ - SELECT COUNT(*) FROM events - WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread - """ - - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() - - return row[0] if row else 0 - AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql b/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql deleted file mode 100644 index 531b532c73..0000000000 --- a/synapse/storage/databases/main/schema/delta/58/12unread_messages.sql +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Store a boolean value in the events table for whether the event should be counted in --- the unread_count property of sync responses. -ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 51941f99f9..8933b560d2 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -165,26 +165,6 @@ class RestHelper(object): return channel.json_body - def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - - path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) - if tok: - path = path + "?access_token=%s" % tok - - request, channel = make_request( - self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") - ) - render(request, self.resource, self.hs.get_reactor()) - - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) - ) - - return channel.json_body - def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index a31e44c97e..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import read_marker, sync +from synapse.rest.client.v2_alpha import sync from tests import unittest from tests.server import TimedOutException @@ -324,156 +324,3 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - request, channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.render(request) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: True): - """Syncs and compares the unread count with the expected value.""" - - request, channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, - ) - self.render(request) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1 From e04e465b4d2c66acb8885c31736c7b7bb4e7be52 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 17 Aug 2020 17:05:00 +0100 Subject: Use the default templates when a custom template file cannot be found (#8037) Fixes https://github.com/matrix-org/synapse/issues/6583 --- changelog.d/8037.feature | 1 + docs/sample_config.yaml | 4 +- synapse/config/_base.py | 100 ++++++++++++++++++++- synapse/config/emailconfig.py | 145 ++++++++++++++----------------- synapse/config/saml2_config.py | 14 +-- synapse/config/sso.py | 37 ++++---- synapse/handlers/account_validity.py | 20 +---- synapse/handlers/auth.py | 12 ++- synapse/handlers/oidc_handler.py | 5 +- synapse/push/mailer.py | 72 +-------------- synapse/push/pusher.py | 31 ++----- synapse/python_dependencies.py | 2 - synapse/rest/client/v2_alpha/account.py | 44 +++------- synapse/rest/client/v2_alpha/register.py | 31 ++----- tests/config/test_base.py | 82 +++++++++++++++++ 15 files changed, 310 insertions(+), 290 deletions(-) create mode 100644 changelog.d/8037.feature create mode 100644 tests/config/test_base.py (limited to 'synapse/push') diff --git a/changelog.d/8037.feature b/changelog.d/8037.feature new file mode 100644 index 0000000000..2e5127477d --- /dev/null +++ b/changelog.d/8037.feature @@ -0,0 +1 @@ +Use the default template file when its equivalent is not found in a custom template directory. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9235b89fb1..f168853f67 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2002,9 +2002,7 @@ email: # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index fd137853b1..1417487427 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,12 +18,16 @@ import argparse import errno import os +import time +import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, List, MutableMapping, Optional +from typing import Any, Callable, List, MutableMapping, Optional import attr +import jinja2 +import pkg_resources import yaml @@ -100,6 +104,11 @@ class Config(object): def __init__(self, root_config=None): self.root = root_config + # Get the path to the default Synapse template directory + self.default_template_dir = pkg_resources.resource_filename( + "synapse", "res/templates" + ) + def __getattr__(self, item: str) -> Any: """ Try and fetch a configuration option that does not exist on this class. @@ -184,6 +193,95 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() + def read_templates( + self, filenames: List[str], custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Load a list of template files from disk using the given variables. + + This function will attempt to load the given templates from the default Synapse + template directory. If `custom_template_directory` is supplied, that directory + is tried first. + + Files read are treated as Jinja templates. These templates are not rendered yet. + + Args: + filenames: A list of template filenames to read. + + custom_template_directory: A directory to try to look for the templates + before using the default Synapse template directory instead. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A list of jinja2 templates. + """ + templates = [] + search_directories = [self.default_template_dir] + + # The loader will first look in the custom template directory (if specified) for the + # given filename. If it doesn't find it, it will use the default template dir instead + if custom_template_directory: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.insert(0, custom_template_directory) + + loader = jinja2.FileSystemLoader(search_directories) + env = jinja2.Environment(loader=loader, autoescape=True) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) + + for filename in filenames: + # Load the template + template = env.get_template(filename) + templates.append(template) + + return templates + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + class RootConfig(object): """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index a63acbdc63..7a796996c0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -23,7 +23,6 @@ from enum import Enum from typing import Optional import attr -import pkg_resources from ._base import Config, ConfigError @@ -98,21 +97,18 @@ class EmailConfig(Config): if parsed[1] == "": raise RuntimeError("Invalid notif_from address") + # A user-configurable template directory template_dir = email_config.get("template_dir") - # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxiliary templates like mail.css we will need). - # (Note that loading as package_resources with jinja.PackageLoader doesn't - # work for the same reason.) - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - self.email_template_dir = os.path.abspath(template_dir) + if isinstance(template_dir, str): + # We need an absolute path, because we change directory after starting (and + # we don't yet know what auxiliary templates like mail.css we will need). + template_dir = os.path.abspath(template_dir) + elif template_dir is not None: + # If template_dir is something other than a str or None, warn the user + raise ConfigError("Config option email.template_dir must be type str") self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_config = config.get("account_validity") or {} - account_validity_renewal_enabled = account_validity_config.get("renew_at") - self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email # is not defined @@ -166,19 +162,6 @@ class EmailConfig(Config): email_config.get("validation_token_lifetime", "1h") ) - if ( - self.email_enable_notifs - or account_validity_renewal_enabled - or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): - # make sure we can import the required deps - import bleach - import jinja2 - - # prevent unused warnings - jinja2 - bleach - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: missing = [] if not self.email_notif_from: @@ -196,49 +179,49 @@ class EmailConfig(Config): # These email templates have placeholders in them, and thus must be # parsed using a templating engine during a request - self.email_password_reset_template_html = email_config.get( + password_reset_template_html = email_config.get( "password_reset_template_html", "password_reset.html" ) - self.email_password_reset_template_text = email_config.get( + password_reset_template_text = email_config.get( "password_reset_template_text", "password_reset.txt" ) - self.email_registration_template_html = email_config.get( + registration_template_html = email_config.get( "registration_template_html", "registration.html" ) - self.email_registration_template_text = email_config.get( + registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) - self.email_add_threepid_template_html = email_config.get( + add_threepid_template_html = email_config.get( "add_threepid_template_html", "add_threepid.html" ) - self.email_add_threepid_template_text = email_config.get( + add_threepid_template_text = email_config.get( "add_threepid_template_text", "add_threepid.txt" ) - self.email_password_reset_template_failure_html = email_config.get( + password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) - self.email_registration_template_failure_html = email_config.get( + registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) - self.email_add_threepid_template_failure_html = email_config.get( + add_threepid_template_failure_html = email_config.get( "add_threepid_template_failure_html", "add_threepid_failure.html" ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup - email_password_reset_template_success_html = email_config.get( + password_reset_template_success_html = email_config.get( "password_reset_template_success_html", "password_reset_success.html" ) - email_registration_template_success_html = email_config.get( + registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) - email_add_threepid_template_success_html = email_config.get( + add_threepid_template_success_html = email_config.get( "add_threepid_template_success_html", "add_threepid_success.html" ) - # Check templates exist - for f in [ + # Read all templates from disk + ( self.email_password_reset_template_html, self.email_password_reset_template_text, self.email_registration_template_html, @@ -248,32 +231,36 @@ class EmailConfig(Config): self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, - email_password_reset_template_success_html, - email_registration_template_success_html, - email_add_threepid_template_success_html, - ]: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p,)) - - # Retrieve content of web templates - filepath = os.path.join( - self.email_template_dir, email_password_reset_template_success_html + password_reset_template_success_html_template, + registration_template_success_html_template, + add_threepid_template_success_html_template, + ) = self.read_templates( + [ + password_reset_template_html, + password_reset_template_text, + registration_template_html, + registration_template_text, + add_threepid_template_html, + add_threepid_template_text, + password_reset_template_failure_html, + registration_template_failure_html, + add_threepid_template_failure_html, + password_reset_template_success_html, + registration_template_success_html, + add_threepid_template_success_html, + ], + template_dir, ) - self.email_password_reset_template_success_html = self.read_file( - filepath, "email.password_reset_template_success_html" - ) - filepath = os.path.join( - self.email_template_dir, email_registration_template_success_html - ) - self.email_registration_template_success_html_content = self.read_file( - filepath, "email.registration_template_success_html" + + # Render templates that do not contain any placeholders + self.email_password_reset_template_success_html_content = ( + password_reset_template_success_html_template.render() ) - filepath = os.path.join( - self.email_template_dir, email_add_threepid_template_success_html + self.email_registration_template_success_html_content = ( + registration_template_success_html_template.render() ) - self.email_add_threepid_template_success_html_content = self.read_file( - filepath, "email.add_threepid_template_success_html" + self.email_add_threepid_template_success_html_content = ( + add_threepid_template_success_html_template.render() ) if self.email_enable_notifs: @@ -290,17 +277,19 @@ class EmailConfig(Config): % (", ".join(missing),) ) - self.email_notif_template_html = email_config.get( + notif_template_html = email_config.get( "notif_template_html", "notif_mail.html" ) - self.email_notif_template_text = email_config.get( + notif_template_text = email_config.get( "notif_template_text", "notif_mail.txt" ) - for f in self.email_notif_template_text, self.email_notif_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.email_notif_template_html, + self.email_notif_template_text, + ) = self.read_templates( + [notif_template_html, notif_template_text], template_dir, + ) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True @@ -309,18 +298,20 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if account_validity_renewal_enabled: - self.email_expiry_template_html = email_config.get( + if self.account_validity.renew_by_email_enabled: + expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) - self.email_expiry_template_text = email_config.get( + expiry_template_text = email_config.get( "expiry_template_text", "notice_expiry.txt" ) - for f in self.email_expiry_template_text, self.email_expiry_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.account_validity_template_html, + self.account_validity_template_text, + ) = self.read_templates( + [expiry_template_html, expiry_template_text], template_dir, + ) subjects_config = email_config.get("subjects", {}) subjects = {} @@ -400,9 +391,7 @@ class EmailConfig(Config): # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 9277b5f342..036f8c0e90 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -18,8 +18,6 @@ import logging from typing import Any, List import attr -import jinja2 -import pkg_resources from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -171,15 +169,9 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) - template_dir = saml2_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - - loader = jinja2.FileSystemLoader(template_dir) - # enable auto-escape here, to having to remember to escape manually in the - # template - env = jinja2.Environment(loader=loader, autoescape=True) - self.saml2_error_html_template = env.get_template("saml_error.html") + self.saml2_error_html_template = self.read_templates( + ["saml_error.html"], saml2_config.get("template_dir") + ) def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 73b7296399..4427676167 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any, Dict -import pkg_resources - from ._base import Config @@ -29,22 +26,32 @@ class SSOConfig(Config): def read_config(self, config, **kwargs): sso_config = config.get("sso") or {} # type: Dict[str, Any] - # Pick a template directory in order of: - # * The sso-specific template_dir - # * /path/to/synapse/install/res/templates + # The sso-specific template_dir template_dir = sso_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_template_dir = template_dir - self.sso_account_deactivated_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), - "sso_account_deactivated_template", + # Read templates from disk + ( + self.sso_redirect_confirm_template, + self.sso_auth_confirm_template, + self.sso_error_template, + sso_account_deactivated_template, + sso_auth_success_template, + ) = self.read_templates( + [ + "sso_redirect_confirm.html", + "sso_auth_confirm.html", + "sso_error.html", + "sso_account_deactivated.html", + "sso_auth_success.html", + ], + template_dir, ) - self.sso_auth_success_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_auth_success.html"), - "sso_auth_success_template", + + # These templates have no placeholders, so render them here + self.sso_account_deactivated_template = ( + sso_account_deactivated_template.render() ) + self.sso_auth_success_template = sso_auth_success_template.render() self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 590135d19c..b865bf5b48 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils -try: - from synapse.push.mailer import load_jinja2_templates -except ImportError: - load_jinja2_templates = None - logger = logging.getLogger(__name__) @@ -47,9 +42,11 @@ class AccountValidityHandler(object): if ( self._account_validity.enabled and self._account_validity.renew_by_email_enabled - and load_jinja2_templates ): # Don't do email-specific configuration if renewal by email is disabled. + self._template_html = self.config.account_validity_template_html + self._template_text = self.config.account_validity_template_text + try: app_name = self.hs.config.email_app_name @@ -65,17 +62,6 @@ class AccountValidityHandler(object): self._raw_from = email.utils.parseaddr(self._from_string)[1] - self._template_html, self._template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_expiry_template_html, - self.config.email_expiry_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) - # Check the renewal emails to send and send them every 30min. def send_emails(): # run as a background process to make sure that the database transactions diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c24e7bafe0..68d6870e40 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email @@ -132,18 +131,17 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_redirect_confirm.html"], - )[0] + self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_auth_confirm.html"], - )[0] + self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index fa5ee5de8f..87d28a7ae9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -38,7 +38,6 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.push.mailer import load_jinja2_templates from synapse.types import UserID, map_username_to_mxid_localpart if TYPE_CHECKING: @@ -123,9 +122,7 @@ class OidcHandler: self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_error.html"] - )[0] + self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index af117fddf9..c38e037281 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -16,8 +16,7 @@ import email.mime.multipart import email.utils import logging -import time -import urllib +import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar @@ -640,72 +639,3 @@ def string_ordinal_total(s): for c in s: tot += ord(c) return tot - - -def format_ts_filter(value, format): - return time.strftime(format, time.localtime(value / 1000)) - - -def load_jinja2_templates( - template_dir, - template_filenames, - apply_format_ts_filter=False, - apply_mxc_to_http_filter=False, - public_baseurl=None, -): - """Loads and returns one or more jinja2 templates and applies optional filters - - Args: - template_dir (str): The directory where templates are stored - template_filenames (list[str]): A list of template filenames - apply_format_ts_filter (bool): Whether to apply a template filter that formats - timestamps - apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts - mxc urls to http urls - public_baseurl (str|None): The public baseurl of the server. Required for - apply_mxc_to_http_filter to be enabled - - Returns: - A list of jinja2 templates corresponding to the given list of filenames, - with order preserved - """ - logger.info( - "loading email templates %s from '%s'", template_filenames, template_dir - ) - loader = jinja2.FileSystemLoader(template_dir) - env = jinja2.Environment(loader=loader) - - if apply_format_ts_filter: - env.filters["format_ts"] = format_ts_filter - - if apply_mxc_to_http_filter and public_baseurl: - env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl) - - templates = [] - for template_filename in template_filenames: - template = env.get_template(template_filename) - templates.append(template) - - return templates - - -def _create_mxc_to_http_filter(public_baseurl): - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if "#" in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - serverAndMediaId, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8ad0bf5936..f626797133 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -15,22 +15,13 @@ import logging +from synapse.push.emailpusher import EmailPusher +from synapse.push.mailer import Mailer + from .httppusher import HttpPusher logger = logging.getLogger(__name__) -# We try importing this if we can (it will fail if we don't -# have the optional email dependencies installed). We don't -# yet have the config to know if we need the email pusher, -# but importing this after daemonizing seems to fail -# (even though a simple test of importing from a daemonized -# process works fine) -try: - from synapse.push.emailpusher import EmailPusher - from synapse.push.mailer import Mailer, load_jinja2_templates -except Exception: - pass - class PusherFactory(object): def __init__(self, hs): @@ -43,16 +34,8 @@ class PusherFactory(object): if hs.config.email_enable_notifs: self.mailers = {} # app_name -> Mailer - self.notif_template_html, self.notif_template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_notif_template_html, - self.config.email_notif_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) + self._notif_template_html = hs.config.email_notif_template_html + self._notif_template_text = hs.config.email_notif_template_text self.pusher_types["email"] = self._create_email_pusher @@ -73,8 +56,8 @@ class PusherFactory(object): mailer = Mailer( hs=self.hs, app_name=app_name, - template_html=self.notif_template_html, - template_text=self.notif_template_text, + template_html=self._notif_template_html, + template_text=self._notif_template_text, ) self.mailers[app_name] = mailer return EmailPusher(self.hs, pusherdict, mailer) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e5f22fb858..3250d41dde 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 - "resources.consent": ["Jinja2>=2.9"], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index fead85074b..203e76b9f2 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -32,7 +32,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import canonicalise_email, check_3pid_allowed @@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_password_reset_template_html, - self.config.email_password_reset_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, ) async def on_POST(self, request): @@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], + self._failure_email_template = ( + self.config.email_password_reset_template_failure_html ) async def on_GET(self, request, medium): @@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_password_reset_template_success_html + html = self.config.email_password_reset_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) @@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_add_threepid_template_html, - self.config.email_add_threepid_template_text, - ], - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, ) async def on_POST(self, request): @@ -578,9 +559,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_add_threepid_template_failure_html], + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html ) async def on_GET(self, request): @@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f808175698..7290fd0756 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -44,7 +44,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - from synapse.push.mailer import Mailer, load_jinja2_templates - - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_registration_template_html, - self.config.email_registration_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, ) async def on_POST(self, request): @@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], + self._failure_email_template = ( + self.config.email_registration_template_failure_html ) async def on_GET(self, request, medium): @@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/tests/config/test_base.py b/tests/config/test_base.py new file mode 100644 index 0000000000..42ee5f56d9 --- /dev/null +++ b/tests/config/test_base.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import tempfile + +from synapse.config import ConfigError +from synapse.util.stringutils import random_string + +from tests import unittest + + +class BaseConfigTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.hs = hs + + def test_loading_missing_templates(self): + # Use a temporary directory that exists on the system, but that isn't likely to + # contain template files + with tempfile.TemporaryDirectory() as tmp_dir: + # Attempt to load an HTML template from our custom template directory + template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + + # If no errors, we should've gotten the default template instead + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"error_description": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_custom_templates(self): + # Use a temporary directory that exists on the system + with tempfile.TemporaryDirectory() as tmp_dir: + # Create a temporary bogus template file + with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template: + # Get temporary file's filename + template_filename = os.path.basename(tmp_template.name) + + # Write a custom HTML template + contents = b"{{ test_variable }}" + tmp_template.write(contents) + tmp_template.flush() + + # Attempt to load the template from our custom template directory + template = ( + self.hs.config.read_templates([template_filename], tmp_dir) + )[0] + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"test_variable": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_template_from_nonexistent_custom_directory(self): + with self.assertRaises(ConfigError): + self.hs.config.read_templates( + ["some_filename.html"], "a_nonexistent_directory" + ) -- cgit 1.5.1