From 8144bc26a7432463b7e70f9c03198d4724952522 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:21:34 -0400 Subject: Convert push to async/await. (#7948) --- synapse/push/action_generator.py | 7 +--- synapse/push/bulk_push_rule_evaluator.py | 62 ++++++++++++---------------- synapse/push/httppusher.py | 58 ++++++++++++-------------- synapse/push/presentable_names.py | 15 +++---- synapse/push/push_tools.py | 22 ++++------ synapse/push/pusherpool.py | 70 +++++++++++++------------------- 6 files changed, 95 insertions(+), 139 deletions(-) (limited to 'synapse/push') diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 1ffd5e2df3..0d23142653 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.metrics import Measure from .bulk_push_rule_evaluator import BulkPushRuleEvaluator @@ -37,7 +35,6 @@ class ActionGenerator(object): # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). - @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user(event, context) + await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 472ddf9f7d..04b9d8ac82 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,6 @@ from collections import namedtuple from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY @@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object): resizable=False, ) - @defer.inlineCallbacks - def _get_rules_for_event(self, event, context): + async def _get_rules_for_event(self, event, context): """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = yield self._get_rules_for_room(room_id) + rules_for_room = await self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(event, context) + rules_by_user = await rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = yield self.store.user_has_pusher(invited) + has_pusher = await self.store.user_has_pusher(invited) if has_pusher: rules_by_user = dict(rules_by_user) - rules_by_user[invited] = yield self.store.get_push_rules_for_user( + rules_by_user[invited] = await self.store.get_push_rules_for_user( invited ) @@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics, ) - @defer.inlineCallbacks - def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids() + async def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = yield self.store.get_event(pl_event_id) + pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = yield self.auth.compute_auth_events( + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level - @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): + async def action_for_event_by_user(self, event, context) -> None: """Given an event and context, evaluate the push rules and insert the results into the event_push_actions_staging table. - - Returns: - Deferred """ - rules_by_user = yield self._get_rules_for_event(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context(event, context) + room_members = await self.store.get_joined_users_from_context(event, context) ( power_levels, sender_power_level, - ) = yield self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object): continue if not event.is_state(): - is_ignored = yield self.store.is_ignored_by(event.sender, uid) + is_ignored = await self.store.is_ignored_by(event.sender, uid) if is_ignored: continue @@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -274,8 +266,7 @@ class RulesForRoom(object): # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - @defer.inlineCallbacks - def get_rules(self, event, context): + async def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -286,7 +277,7 @@ class RulesForRoom(object): self.room_push_rule_cache_metrics.inc_hits() return self.rules_by_user - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() @@ -304,9 +295,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield defer.ensureDeferred( - context.get_current_state_ids() - ) + current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -353,7 +342,7 @@ class RulesForRoom(object): # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) - yield self._update_rules_with_member_event_ids( + await self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) else: @@ -371,8 +360,7 @@ class RulesForRoom(object): ) return ret_rules_by_user - @defer.inlineCallbacks - def _update_rules_with_member_event_ids( + async def _update_rules_with_member_event_ids( self, ret_rules_by_user, member_event_ids, state_group, event ): """Update the partially filled rules_by_user dict by fetching rules for @@ -388,7 +376,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) + rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} @@ -410,7 +398,7 @@ class RulesForRoom(object): logger.debug("Joined: %r", interested_in_user_ids) - if_users_with_pushers = yield self.store.get_if_users_have_pushers( + if_users_with_pushers = await self.store.get_if_users_have_pushers( interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) @@ -420,7 +408,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) - users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + users_with_receipts = await self.store.get_users_with_read_receipts_in_room( self.room_id, on_invalidate=self.invalidate_all_cb ) @@ -431,7 +419,7 @@ class RulesForRoom(object): if uid in interested_in_user_ids: user_ids.add(uid) - rules_by_user = yield self.store.bulk_get_push_rules( + rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2fac07593b..4c469efb20 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -17,7 +17,6 @@ import logging from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.api.constants import EventTypes @@ -128,12 +127,11 @@ class HttpPusher(object): # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - @defer.inlineCallbacks - def _update_badge(self): + async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - yield self._send_badge(badge) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + await self._send_badge(badge) def on_timer(self): self._start_processing() @@ -152,8 +150,7 @@ class HttpPusher(object): run_as_background_process("httppush.process", self._process) - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -164,7 +161,7 @@ class HttpPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -172,8 +169,7 @@ class HttpPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to @@ -181,7 +177,7 @@ class HttpPusher(object): """ fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = yield fn( + unprocessed = await fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -203,13 +199,13 @@ class HttpPusher(object): "app_display_name": self.app_display_name, }, ): - processed = yield self._process_one(push_action) + processed = await self._process_one(push_action) if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, self.user_id, @@ -224,14 +220,14 @@ class HttpPusher(object): if self.failing_since: self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) @@ -250,7 +246,7 @@ class HttpPusher(object): ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, @@ -263,7 +259,7 @@ class HttpPusher(object): return self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: @@ -276,18 +272,17 @@ class HttpPusher(object): ) break - @defer.inlineCallbacks - def _process_one(self, push_action): + async def _process_one(self, push_action): if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: return True # It's been redacted - rejected = yield self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push(event, tweaks, badge) if rejected is False: return False @@ -301,11 +296,10 @@ class HttpPusher(object): ) else: logger.info("Pushkey %s was rejected: removing", pk) - yield self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict(self, event, tweaks, badge): priority = "low" if ( event.type == EventTypes.Encrypted @@ -335,7 +329,7 @@ class HttpPusher(object): } return d - ctx = yield push_tools.get_context_for_event( + ctx = await push_tools.get_context_for_event( self.storage, self.state_handler, event, self.user_id ) @@ -377,13 +371,12 @@ class HttpPusher(object): return d - @defer.inlineCallbacks - def dispatch_push(self, event, tweaks, badge): - notification_dict = yield self._build_notification_dict(event, tweaks, badge) + async def dispatch_push(self, event, tweaks, badge): + notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] try: - resp = yield self.http_client.post_json_get_json( + resp = await self.http_client.post_json_get_json( self.url, notification_dict ) except Exception as e: @@ -400,8 +393,7 @@ class HttpPusher(object): rejected = resp["rejected"] return rejected - @defer.inlineCallbacks - def _send_badge(self, badge): + async def _send_badge(self, badge): """ Args: badge (int): number of unread messages @@ -424,7 +416,7 @@ class HttpPusher(object): } } try: - yield self.http_client.post_json_get_json(self.url, d) + await self.http_client.post_json_get_json(self.url, d) http_badges_processed_counter.inc() except Exception as e: logger.warning( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0644a13cfc..d8f4a453cd 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -@defer.inlineCallbacks -def calculate_room_name( +async def calculate_room_name( store, room_state_ids, user_id, @@ -53,7 +50,7 @@ def calculate_room_name( """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: - m_room_name = yield store.get_event( + m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: @@ -61,7 +58,7 @@ def calculate_room_name( # does it have a canonical alias? if (EventTypes.CanonicalAlias, "") in room_state_ids: - canon_alias = yield store.get_event( + canon_alias = await store.get_event( room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( @@ -81,7 +78,7 @@ def calculate_room_name( my_member_event = None if (EventTypes.Member, user_id) in room_state_ids: - my_member_event = yield store.get_event( + my_member_event = await store.get_event( room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) @@ -90,7 +87,7 @@ def calculate_room_name( and my_member_event.content["membership"] == "invite" ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: - inviter_member_event = yield store.get_event( + inviter_member_event = await store.get_event( room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) @@ -107,7 +104,7 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: - member_events = yield store.get_events( + member_events = await store.get_events( list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 5dae4648c0..d0145666bf 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage -@defer.inlineCallbacks -def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_local_user(user_id) - joins = yield store.get_rooms_for_user(user_id) +async def get_badge_count(store, user_id): + invites = await store.get_invited_rooms_for_local_user(user_id) + joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -32,7 +29,7 @@ def get_badge_count(store, user_id): if room_id in my_receipts_by_room: last_unread_event_id = my_receipts_by_room[room_id] - notifs = yield ( + notifs = await ( store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, last_unread_event_id ) @@ -43,23 +40,22 @@ def get_badge_count(store, user_id): return badge -@defer.inlineCallbacks -def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) + room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = yield calculate_room_name( + name = await calculate_room_name( storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield storage.main.get_event(sender_state_event_id) + sender_state_event = await storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2456f12f46..3c3262a88c 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -52,7 +50,7 @@ class PusherPool: Note that it is expected that each pusher will have its own 'processing' loop which will send out the notifications in the background, rather than blocking until the notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and - Pusher.on_new_receipts are not expected to return deferreds. + Pusher.on_new_receipts are not expected to return awaitables. """ def __init__(self, hs: "HomeServer"): @@ -77,8 +75,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -94,7 +91,7 @@ class PusherPool: """Creates a new pusher and adds it to the pool Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ time_now_msec = self.clock.time_msec() @@ -124,9 +121,9 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() + last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - yield self.store.add_pusher( + await self.store.add_pusher( user_id=user_id, access_token=access_token, kind=kind, @@ -140,15 +137,14 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) return pusher - @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user( + async def remove_pushers_by_app_id_and_pushkey_not_user( self, app_id, pushkey, not_user_id ): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: if p["user_name"] != not_user_id: logger.info( @@ -157,10 +153,9 @@ class PusherPool: pushkey, p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token(self, user_id, access_tokens): """Remove the pushers for a given user corresponding to a set of access_tokens. @@ -173,7 +168,7 @@ class PusherPool: return tokens = set(access_tokens) - for p in (yield self.store.get_pushers_by_user_id(user_id)): + for p in await self.store.get_pushers_by_user_id(user_id): if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", @@ -181,16 +176,15 @@ class PusherPool: p["pushkey"], p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. return try: - users_affected = yield self.store.get_push_action_users_in_range( + users_affected = await self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) @@ -202,8 +196,7 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): if not self.pushers: # nothing to do here. return @@ -211,7 +204,7 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - users_affected = yield self.store.get_users_sent_receipts_between( + users_affected = await self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) @@ -223,12 +216,11 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - @defer.inlineCallbacks - def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id(self, app_id, pushkey, user_id): """Look up the details for the given pusher, and start it Returns: - Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + EmailPusher|HttpPusher|None: The pusher started, if any """ if not self._should_start_pushers: return @@ -236,7 +228,7 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: @@ -245,34 +237,29 @@ class PusherPool: pusher = None if pusher_dict: - pusher = yield self._start_pusher(pusher_dict) + pusher = await self._start_pusher(pusher_dict) return pusher - @defer.inlineCallbacks - def _start_pushers(self): + async def _start_pushers(self) -> None: """Start all the pushers - - Returns: - Deferred """ - pushers = yield self.store.get_all_pushers() + pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. - yield concurrently_execute(self._start_pusher, pushers, 10) + await concurrently_execute(self._start_pusher, pushers, 10) logger.info("Started pushers") - @defer.inlineCallbacks - def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusherdict): """Start the given pusher Args: pusherdict (dict): dict with the values pulled from the db table Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ if not self._pusher_shard_config.should_handle( self._instance_name, pusherdict["user_name"] @@ -315,7 +302,7 @@ class PusherPool: user_id = pusherdict["user_name"] last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: - have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering ) else: @@ -327,8 +314,7 @@ class PusherPool: return p - @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) @@ -340,6 +326,6 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) -- cgit 1.5.1 From 8dff4a12424cda9e4abaa5f2905d58aa6e723777 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 29 Jul 2020 18:26:55 +0100 Subject: Re-implement unread counts (#7736) --- changelog.d/7736.feature | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 + synapse/push/push_tools.py | 17 +-- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/data_stores/main/cache.py | 1 + synapse/storage/data_stores/main/events.py | 48 ++++++- synapse/storage/data_stores/main/events_worker.py | 86 ++++++++++- .../main/schema/delta/58/12unread_messages.sql | 18 +++ tests/rest/client/v1/utils.py | 20 +++ tests/rest/client/v2_alpha/test_sync.py | 157 ++++++++++++++++++++- 11 files changed, 339 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7736.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql (limited to 'synapse/push') diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature new file mode 100644 index 0000000000..c97864677a --- /dev/null +++ b/changelog.d/7736.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22a6abd7d2..bee525197f 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], + "events": ["processed", "outlier", "contains_url", "count_as_unread"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ebd3e98105..eaa4eeadf7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,6 +103,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1886,6 +1887,10 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] + + unread_count = await self.store.get_unread_message_count_for_user( + room_id, sync_config.user.to_string(), + ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1894,6 +1899,7 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666bf..bc8f71916b 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,22 +21,13 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") - badge = len(invites) for room_id in joins: - if room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[room_id] - - notifs = await ( - store.get_unread_event_push_actions_by_room_for_user( - room_id, user_id, last_unread_event_id - ) - ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + unread_count = await store.get_unread_message_count_for_user(room_id, user_id) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if unread_count else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..3f5bf75e59 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index f39f556c20..edc3624fed 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) + self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6f2e0d15cc..0c9c02afa1 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -53,6 +53,47 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + def encode_json(json_object): """ @@ -196,6 +237,10 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() + self.store.get_unread_message_count_for_user.invalidate_many( + (event.room_id,), + ) + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -817,8 +862,9 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), + "count_as_unread": should_count_as_unread(event, context), } - for event, _ in events_and_contexts + for event, context in events_and_contexts ], ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e812c67078..b03b259636 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. + + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. + txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql new file mode 100644 index 0000000000..531b532c73 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Store a boolean value in the events table for whether the event should be counted in +-- the unread_count property of sync responses. +ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 22d734e763..7f8252330a 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -143,6 +143,26 @@ class RestHelper(object): return channel.json_body + def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body + def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1bd..a31e44c97e 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: True): + """Syncs and compares the unread count with the expected value.""" + + request, channel = self.make_request( + "GET", self.url % self.next_batch, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] -- cgit 1.5.1