diff options
Diffstat (limited to 'synapse/push')
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 161 | ||||
-rw-r--r-- | synapse/push/emailpusher.py | 9 | ||||
-rw-r--r-- | synapse/push/pusherpool.py | 8 |
3 files changed, 105 insertions, 73 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 50b470c310..350646f458 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -106,6 +106,10 @@ class BulkPushRuleEvaluator: self.store = hs.get_datastore() self.auth = hs.get_auth() + # Used by `RulesForRoom` to ensure only one thing mutates the cache at a + # time. Keyed off room_id. + self._rules_linearizer = Linearizer(name="rules_for_room") + self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", @@ -123,7 +127,16 @@ class BulkPushRuleEvaluator: dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = self._get_rules_for_room(room_id) + + rules_for_room_data = self._get_rules_for_room(room_id) + rules_for_room = RulesForRoom( + hs=self.hs, + room_id=room_id, + rules_for_room_cache=self._get_rules_for_room.cache, + room_push_rule_cache_metrics=self.room_push_rule_cache_metrics, + linearizer=self._rules_linearizer, + cached_data=rules_for_room_data, + ) rules_by_user = await rules_for_room.get_rules(event, context) @@ -142,17 +155,12 @@ class BulkPushRuleEvaluator: return rules_by_user @lru_cache() - def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": - """Get the current RulesForRoom object for the given room id""" - # It's important that RulesForRoom gets added to self._get_rules_for_room.cache + def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData": + """Get the current RulesForRoomData object for the given room id""" + # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) - return RulesForRoom( - self.hs, - room_id, - self._get_rules_for_room.cache, - self.room_push_rule_cache_metrics, - ) + return RulesForRoomData() async def _get_power_levels_and_sender_level( self, event: EventBase, context: EventContext @@ -282,11 +290,49 @@ def _condition_checker( return True +@attr.s(slots=True) +class RulesForRoomData: + """The data stored in the cache by `RulesForRoom`. + + We don't store `RulesForRoom` directly in the cache as we want our caches to + *only* include data, and not references to e.g. the data stores. + """ + + # event_id -> (user_id, state) + member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict) + # user_id -> rules + rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict) + + # The last state group we updated the caches for. If the state_group of + # a new event comes along, we know that we can just return the cached + # result. + # On invalidation of the rules themselves (if the user changes them), + # we invalidate everything and set state_group to `object()` + state_group = attr.ib(type=Union[object, int], factory=object) + + # A sequence number to keep track of when we're allowed to update the + # cache. We bump the sequence number when we invalidate the cache. If + # the sequence number changes while we're calculating stuff we should + # not update the cache with it. + sequence = attr.ib(type=int, default=0) + + # A cache of user_ids that we *know* aren't interesting, e.g. user_ids + # owned by AS's, or remote users, etc. (I.e. users we will never need to + # calculate push for) + # These never need to be invalidated as we will never set up push for + # them. + uninteresting_user_set = attr.ib(type=Set[str], factory=set) + + class RulesForRoom: """Caches push rules for users in a room. This efficiently handles users joining/leaving the room by not invalidating the entire cache for the room. + + A new instance is constructed for each call to + `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from + previous calls passed in. """ def __init__( @@ -295,6 +341,8 @@ class RulesForRoom: room_id: str, rules_for_room_cache: LruCache, room_push_rule_cache_metrics: CacheMetric, + linearizer: Linearizer, + cached_data: RulesForRoomData, ): """ Args: @@ -303,38 +351,21 @@ class RulesForRoom: rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics: The metrics object + linearizer: The linearizer used to ensure only one thing mutates + the cache at a time. Keyed off room_id + cached_data: Cached data from previous calls to `self.get_rules`, + can be mutated. """ self.room_id = room_id self.is_mine_id = hs.is_mine_id self.store = hs.get_datastore() self.room_push_rule_cache_metrics = room_push_rule_cache_metrics - self.linearizer = Linearizer(name="rules_for_room") - - # event_id -> (user_id, state) - self.member_map = {} # type: Dict[str, Tuple[str, str]] - # user_id -> rules - self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]] - - # The last state group we updated the caches for. If the state_group of - # a new event comes along, we know that we can just return the cached - # result. - # On invalidation of the rules themselves (if the user changes them), - # we invalidate everything and set state_group to `object()` - self.state_group = object() - - # A sequence number to keep track of when we're allowed to update the - # cache. We bump the sequence number when we invalidate the cache. If - # the sequence number changes while we're calculating stuff we should - # not update the cache with it. - self.sequence = 0 - - # A cache of user_ids that we *know* aren't interesting, e.g. user_ids - # owned by AS's, or remote users, etc. (I.e. users we will never need to - # calculate push for) - # These never need to be invalidated as we will never set up push for - # them. - self.uninteresting_user_set = set() # type: Set[str] + # Used to ensure only one thing mutates the cache at a time. Keyed off + # room_id. + self.linearizer = linearizer + + self.data = cached_data # We need to be clever on the invalidating caches callbacks, as # otherwise the invalidation callback holds a reference to the object, @@ -352,25 +383,25 @@ class RulesForRoom: """ state_group = context.state_group - if state_group and self.state_group == state_group: + if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - return self.rules_by_user + return self.data.rules_by_user - with (await self.linearizer.queue(())): - if state_group and self.state_group == state_group: + with (await self.linearizer.queue(self.room_id)): + if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() - return self.rules_by_user + return self.data.rules_by_user self.room_push_rule_cache_metrics.inc_misses() ret_rules_by_user = {} missing_member_event_ids = {} - if state_group and self.state_group == context.prev_group: + if state_group and self.data.state_group == context.prev_group: # If we have a simple delta then we can reuse most of the previous # results. - ret_rules_by_user = self.rules_by_user + ret_rules_by_user = self.data.rules_by_user current_state_ids = context.delta_ids push_rules_delta_state_cache_metric.inc_hits() @@ -393,24 +424,24 @@ class RulesForRoom: if typ != EventTypes.Member: continue - if user_id in self.uninteresting_user_set: + if user_id in self.data.uninteresting_user_set: continue if not self.is_mine_id(user_id): - self.uninteresting_user_set.add(user_id) + self.data.uninteresting_user_set.add(user_id) continue if self.store.get_if_app_services_interested_in_user(user_id): - self.uninteresting_user_set.add(user_id) + self.data.uninteresting_user_set.add(user_id) continue event_id = current_state_ids[key] - res = self.member_map.get(event_id, None) + res = self.data.member_map.get(event_id, None) if res: user_id, state = res if state == Membership.JOIN: - rules = self.rules_by_user.get(user_id, None) + rules = self.data.rules_by_user.get(user_id, None) if rules: ret_rules_by_user[user_id] = rules continue @@ -430,7 +461,7 @@ class RulesForRoom: else: # The push rules didn't change but lets update the cache anyway self.update_cache( - self.sequence, + self.data.sequence, members={}, # There were no membership changes rules_by_user=ret_rules_by_user, state_group=state_group, @@ -461,7 +492,7 @@ class RulesForRoom: for. Used when updating the cache. event: The event we are currently computing push rules for. """ - sequence = self.sequence + sequence = self.data.sequence rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) @@ -501,23 +532,11 @@ class RulesForRoom: self.update_cache(sequence, members, ret_rules_by_user, state_group) - def invalidate_all(self) -> None: - # Note: Don't hand this function directly to an invalidation callback - # as it keeps a reference to self and will stop this instance from being - # GC'd if it gets dropped from the rules_to_user cache. Instead use - # `self.invalidate_all_cb` - logger.debug("Invalidating RulesForRoom for %r", self.room_id) - self.sequence += 1 - self.state_group = object() - self.member_map = {} - self.rules_by_user = {} - push_rules_invalidation_counter.inc() - def update_cache(self, sequence, members, rules_by_user, state_group) -> None: - if sequence == self.sequence: - self.member_map.update(members) - self.rules_by_user = rules_by_user - self.state_group = state_group + if sequence == self.data.sequence: + self.data.member_map.update(members) + self.data.rules_by_user = rules_by_user + self.data.state_group = state_group @attr.attrs(slots=True, frozen=True) @@ -535,6 +554,10 @@ class _Invalidation: room_id = attr.ib(type=str) def __call__(self) -> None: - rules = self.cache.get(self.room_id, None, update_metrics=False) - if rules: - rules.invalidate_all() + rules_data = self.cache.get(self.room_id, None, update_metrics=False) + if rules_data: + rules_data.sequence += 1 + rules_data.state_group = object() + rules_data.member_map = {} + rules_data.rules_by_user = {} + push_rules_invalidation_counter.inc() diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index cd89b54305..99a18874d1 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -19,8 +19,9 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.interfaces import IDelayedCall from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.push import Pusher, PusherConfig, ThrottleParams +from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer +from synapse.util.threepids import validate_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -71,6 +72,12 @@ class EmailPusher(Pusher): self._is_processing = False + # Make sure that the email is valid. + try: + validate_email(self.email) + except ValueError: + raise PusherConfigException("Invalid email") + def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 564a5ed0df..579fcdf472 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -62,7 +62,9 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity = hs.config.account_validity + self._account_validity_enabled = ( + hs.config.account_validity.account_validity_enabled + ) # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config @@ -236,7 +238,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) @@ -266,7 +268,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) |