diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 82a72dc34f..10f27e4378 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
@@ -25,16 +26,16 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import register_cache
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
-rules_by_room = {}
+logger = logging.getLogger(__name__)
push_rules_invalidation_counter = Counter(
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- async def _get_rules_for_event(self, event, context):
+ async def _get_rules_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user
@lru_cache()
- def _get_rules_for_room(self, room_id):
+ def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id
-
- Returns:
- RulesForRoom
"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics,
)
- async def _get_power_levels_and_sender_level(self, event, context):
+ async def _get_power_levels_and_sender_level(
+ self, event: EventBase, context: EventContext
+ ) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
- pl_event = await self.store.get_event(pl_event_id)
- auth_events = {POWER_KEY: pl_event}
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_dict = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def action_for_event_by_user(self, event, context) -> None:
+ async def action_for_event_by_user(
+ self, event: EventBase, context: EventContext
+ ) -> None:
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
- actions_by_user = {}
+ actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context)
@@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels
)
- condition_cache = {}
+ condition_cache = {} # type: Dict[str, bool]
for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
)
-def _condition_checker(evaluator, conditions, uid, display_name, cache):
+def _condition_checker(
+ evaluator: PushRuleEvaluatorForEvent,
+ conditions: List[dict],
+ uid: str,
+ display_name: str,
+ cache: Dict[str, bool],
+) -> bool:
for cond in conditions:
_id = cond.get("_id", None)
if _id:
@@ -277,15 +286,19 @@ class RulesForRoom:
"""
def __init__(
- self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ self,
+ hs: "HomeServer",
+ room_id: str,
+ rules_for_room_cache: LruCache,
+ room_push_rule_cache_metrics: CacheMetric,
):
"""
Args:
- hs (HomeServer)
- room_id (str)
+ hs: The HomeServer object.
+ room_id: The room ID.
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
- room_push_rule_cache_metrics (CacheMetric)
+ room_push_rule_cache_metrics: The metrics object
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
@@ -294,8 +307,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room")
- self.member_map = {} # event_id -> (user_id, state)
- self.rules_by_user = {} # user_id -> rules
+ # event_id -> (user_id, state)
+ self.member_map = {} # type: Dict[str, Tuple[str, str]]
+ # user_id -> rules
+ self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
@@ -315,7 +330,7 @@ class RulesForRoom:
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
- self.uninteresting_user_set = set()
+ self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +340,9 @@ class RulesForRoom:
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
- async def get_rules(self, event, context):
+ async def get_rules(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are
currently in the room.
"""
@@ -356,6 +373,8 @@ class RulesForRoom:
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
+ # Ensure the state IDs exist.
+ assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -420,18 +439,23 @@ class RulesForRoom:
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
- self, ret_rules_by_user, member_event_ids, state_group, event
- ):
+ self,
+ ret_rules_by_user: Dict[str, list],
+ member_event_ids: Dict[str, str],
+ state_group: Optional[int],
+ event: EventBase,
+ ) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
- ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+ ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
- member_event_ids (dict): Dict of user id to event id for membership events
+ member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
+ event: The event we are currently computing push rules for.
"""
sequence = self.sequence
@@ -449,19 +473,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- user_ids = {
+ joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
- logger.debug("Joined: %r", user_ids)
+ logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
- user_ids = list(filter(self.is_mine_id, user_ids))
+ user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +497,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +509,7 @@ class RulesForRoom:
self.rules_by_user = {}
push_rules_invalidation_counter.inc()
- def update_cache(self, sequence, members, rules_by_user, state_group):
+ def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
@@ -506,7 +530,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
- def __call__(self):
+ def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
|