diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 3f75d3f921..fe09d50d55 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-from .bulk_push_rule_evaluator import evaluator_for_event
+from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
@@ -24,11 +24,12 @@ import logging
logger = logging.getLogger(__name__)
-class ActionGenerator:
+class ActionGenerator(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
# also actions for a client with no profile tag for each user.
@@ -38,16 +39,11 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
- with Measure(self.clock, "evaluator_for_event"):
- bulk_evaluator = yield evaluator_for_event(
- event, self.hs, self.store, context
- )
-
with Measure(self.clock, "action_for_event_by_user"):
- actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+ actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
event, context
)
context.push_actions = [
- (uid, actions) for uid, actions in actions_by_user.items()
+ (uid, actions) for uid, actions in actions_by_user.iteritems()
]
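
Note on this change: BulkPushRuleEvaluator used to be built per event by
evaluator_for_event(); it is now constructed once in ActionGenerator.__init__
and reused, so the per-room caches it grows (introduced below) survive across
events. A minimal sketch of that lifetime change, using illustrative names
rather than Synapse APIs:

    class Evaluator(object):
        def __init__(self):
            self.cache = {}  # lives as long as the evaluator does

        def evaluate(self, event_id):
            # misses populate the cache; later events get cheap hits
            return self.cache.setdefault(event_id, len(self.cache))

    class Generator(object):
        def __init__(self):
            # built once, mirroring the new
            # self.bulk_evaluator = BulkPushRuleEvaluator(hs)
            self.evaluator = Evaluator()

        def handle(self, event_id):
            return self.evaluator.evaluate(event_id)

Had the Evaluator been rebuilt per call, every evaluate() would start with an
empty cache, which is exactly the cost the real change removes.
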
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index f943ff640f..9a96e6fe8f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,60 +19,83 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients_context
+from synapse.api.constants import EventTypes, Membership
+from synapse.util.caches.descriptors import cached
+from synapse.util.async import Linearizer
+from collections import namedtuple
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, context):
- rules_by_user = yield store.bulk_get_push_rules_for_room(
- event, context
- )
-
- # if this event is an invite event, we may need to run rules for the user
- # who's been invited, otherwise they won't get told they've been invited
- if event.type == 'm.room.member' and event.content['membership'] == 'invite':
- invited_user = event.state_key
- if invited_user and hs.is_mine_id(invited_user):
- has_pusher = yield store.user_has_pusher(invited_user)
- if has_pusher:
- rules_by_user = dict(rules_by_user)
- rules_by_user[invited_user] = yield store.get_push_rules_for_user(
- invited_user
- )
- defer.returnValue(BulkPushRuleEvaluator(
- event.room_id, rules_by_user, store
- ))
+rules_by_room = {}
-class BulkPushRuleEvaluator:
+class BulkPushRuleEvaluator(object):
+ """Calculates the outcome of push rules for an event for all users in the
+ room at once.
"""
- Runs push rules for all users in a room.
- This is faster than running PushRuleEvaluator for each user because it
- fetches all the rules for all the users in one (batched) db query
- rather than doing multiple queries per-user. It currently uses
- the same logic to run the actual rules, but could be optimised further
- (see https://matrix.org/jira/browse/SYN-562)
- """
- def __init__(self, room_id, rules_by_user, store):
- self.room_id = room_id
- self.rules_by_user = rules_by_user
- self.store = store
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def _get_rules_for_event(self, event, context):
+ """This gets the rules for all users in the room at the time of the event,
+ as well as the push rules for the invitee if the event is an invite.
+
+ Returns:
+ dict of user_id -> push_rules
+ """
+ room_id = event.room_id
+ rules_for_room = self._get_rules_for_room(room_id)
+
+ rules_by_user = yield rules_for_room.get_rules(event, context)
+
+ # if this event is an invite event, we may need to run rules for the user
+ # who's been invited, otherwise they won't get told they've been invited
+ if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+ invited = event.state_key
+ if invited and self.hs.is_mine_id(invited):
+ has_pusher = yield self.store.user_has_pusher(invited)
+ if has_pusher:
+ rules_by_user = dict(rules_by_user)
+ rules_by_user[invited] = yield self.store.get_push_rules_for_user(
+ invited
+ )
+
+ defer.returnValue(rules_by_user)
+
+ @cached()
+ def _get_rules_for_room(self, room_id):
+ """Get the current RulesForRoom object for the given room id
+
+ Returns:
+ RulesForRoom
+ """
+ # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+ # before any lookup methods get called on it as otherwise there may be
+ # a race if invalidate_all gets called (which assumes it's in the cache)
+ return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache)
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
+ """Given an event and context, evaluate the push rules and return
+ the results
+
+ Returns:
+ dict of user_id -> action
+ """
+ rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {}
# None of these users can be peeking since this list of users comes
# from the set of users in the room, so we know for sure they're all
# actually in the room.
- user_tuples = [
- (u, False) for u in self.rules_by_user.keys()
- ]
+ user_tuples = [(u, False) for u in rules_by_user]
filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: context}
@@ -86,7 +109,7 @@ class BulkPushRuleEvaluator:
condition_cache = {}
- for uid, rules in self.rules_by_user.items():
+ for uid, rules in rules_by_user.iteritems():
display_name = None
profile_info = room_members.get(uid)
if profile_info:
@@ -138,3 +161,240 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
return False
return True
+
+
+class RulesForRoom(object):
+ """Caches push rules for users in a room.
+
+ This efficiently handles users joining/leaving the room by not invalidating
+ the entire cache for the room.
+ """
+
+ def __init__(self, hs, room_id, rules_for_room_cache):
+ """
+ Args:
+ hs (HomeServer)
+ room_id (str)
+ rules_for_room_cache (Cache): The cache object that caches these
+ RulesForRoom objects.
+ """
+ self.room_id = room_id
+ self.is_mine_id = hs.is_mine_id
+ self.store = hs.get_datastore()
+
+ self.linearizer = Linearizer(name="rules_for_room")
+
+ self.member_map = {} # event_id -> (user_id, state)
+ self.rules_by_user = {} # user_id -> rules
+
+ # The last state group we updated the caches for. If the state_group of
+ # a new event comes along, we know that we can just return the cached
+ # result.
+ # On invalidation of the rules themselves (if the user changes them),
+ # we invalidate everything and set state_group to `object()`
+ self.state_group = object()
+
+ # A sequence number to keep track of when we're allowed to update the
+ # cache. We bump the sequence number when we invalidate the cache. If
+ # the sequence number changes while we're recalculating, we must not
+ # update the cache with the result.
+ self.sequence = 0
+
+ # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+ # owned by AS's, or remote users, etc. (i.e. users we will never need to
+ # calculate push for).
+ # These never need to be invalidated as we will never set up push for
+ # them.
+ self.uninteresting_user_set = set()
+
+ # We need to be careful with the cache invalidation callbacks, as
+ # otherwise the invalidation callback holds a reference to the object,
+ # potentially causing it to leak.
+ # To get around this we pass a function that on invalidation looks up
+ # the RulesForRoom entry in the cache, rather than keeping a reference
+ # to self around in the callback.
+ self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
+
+ @defer.inlineCallbacks
+ def get_rules(self, event, context):
+ """Given an event and its context, return the rules for all users who
+ are currently in the room.
+ """
+ state_group = context.state_group
+
+ with (yield self.linearizer.queue(())):
+ if state_group and self.state_group == state_group:
+ logger.debug("Using cached rules for %r", self.room_id)
+ defer.returnValue(self.rules_by_user)
+
+ ret_rules_by_user = {}
+ missing_member_event_ids = {}
+ if state_group and self.state_group == context.prev_group:
+ # If we have a simple delta then we can reuse most of the previous
+ # results.
+ ret_rules_by_user = self.rules_by_user
+ current_state_ids = context.delta_ids
+ else:
+ current_state_ids = context.current_state_ids
+
+ logger.debug(
+ "Looking for member changes in %r %r", state_group, current_state_ids
+ )
+
+ # Loop through to see which member events we've seen and have rules
+ # for, and which we need to fetch.
+ for key in current_state_ids:
+ typ, user_id = key
+ if typ != EventTypes.Member:
+ continue
+
+ if user_id in self.uninteresting_user_set:
+ continue
+
+ if not self.is_mine_id(user_id):
+ self.uninteresting_user_set.add(user_id)
+ continue
+
+ if self.store.get_if_app_services_interested_in_user(user_id):
+ self.uninteresting_user_set.add(user_id)
+ continue
+
+ event_id = current_state_ids[key]
+
+ res = self.member_map.get(event_id, None)
+ if res:
+ user_id, state = res
+ if state == Membership.JOIN:
+ rules = self.rules_by_user.get(user_id, None)
+ if rules:
+ ret_rules_by_user[user_id] = rules
+ continue
+
+ # If a user has left the room we remove their push rules. If they have
+ # joined, we re-add them later in _update_rules_with_member_event_ids.
+ ret_rules_by_user.pop(user_id, None)
+ missing_member_event_ids[user_id] = event_id
+
+ if missing_member_event_ids:
+ # If we have some member events we haven't seen, look them up
+ # and fetch push rules for them if appropriate.
+ logger.debug("Found new member events %r", missing_member_event_ids)
+ yield self._update_rules_with_member_event_ids(
+ ret_rules_by_user, missing_member_event_ids, state_group, event
+ )
+
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "Returning push rules for %r %r",
+ self.room_id, ret_rules_by_user.keys(),
+ )
+ defer.returnValue(ret_rules_by_user)
+
+ @defer.inlineCallbacks
+ def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
+ state_group, event):
+ """Update the partially filled rules_by_user dict by fetching rules for
+ any newly joined users given by `member_event_ids`.
+
+ Args:
+ ret_rules_by_user (dict): Partially filled dict of push rules. Gets
+ updated with any new rules.
+ member_event_ids (dict): Dict of user id to event id for membership
+ events that have happened since the last time we filled rules_by_user.
+ state_group: The state group we are currently computing push rules
+ for. Used when updating the cache.
+ """
+ sequence = self.sequence
+
+ rows = yield self.store._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids.values(),
+ retcols=('user_id', 'membership', 'event_id'),
+ keyvalues={},
+ batch_size=500,
+ desc="_get_rules_for_member_event_ids",
+ )
+
+ members = {
+ row["event_id"]: (row["user_id"], row["membership"])
+ for row in rows
+ }
+
+ # If the event is a join event then it will be in the current state
+ # events map but not in the DB, so we have to explicitly insert it.
+ if event.type == EventTypes.Member:
+ for event_id in member_event_ids.itervalues():
+ if event_id == event.event_id:
+ members[event_id] = (event.state_key, event.membership)
+
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Found members %r: %r", self.room_id, members.values())
+
+ interested_in_user_ids = set(
+ user_id for user_id, membership in members.itervalues()
+ if membership == Membership.JOIN
+ )
+
+ logger.debug("Joined: %r", interested_in_user_ids)
+
+ if_users_with_pushers = yield self.store.get_if_users_have_pushers(
+ interested_in_user_ids,
+ on_invalidate=self.invalidate_all_cb,
+ )
+
+ user_ids = set(
+ uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+ )
+
+ logger.debug("With pushers: %r", user_ids)
+
+ users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
+ self.room_id, on_invalidate=self.invalidate_all_cb,
+ )
+
+ logger.debug("With receipts: %r", users_with_receipts)
+
+ # any joined users with read receipts also need their push rules run
+ for uid in users_with_receipts:
+ if uid in interested_in_user_ids:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.store.bulk_get_push_rules(
+ user_ids, on_invalidate=self.invalidate_all_cb,
+ )
+
+ ret_rules_by_user.update(
+ item for item in rules_by_user.iteritems() if item[0] is not None
+ )
+
+ self.update_cache(sequence, members, ret_rules_by_user, state_group)
+
+ def invalidate_all(self):
+ # Note: Don't hand this function directly to an invalidation callback
+ # as it keeps a reference to self and will stop this instance from being
+ # GC'd if it gets dropped from the rules_for_room cache. Instead use
+ # `self.invalidate_all_cb`
+ logger.debug("Invalidating RulesForRoom for %r", self.room_id)
+ self.sequence += 1
+ self.state_group = object()
+ self.member_map = {}
+ self.rules_by_user = {}
+
+ def update_cache(self, sequence, members, rules_by_user, state_group):
+ if sequence == self.sequence:
+ self.member_map.update(members)
+ self.rules_by_user = rules_by_user
+ self.state_group = state_group
+
+
+class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
+ # We rely on _Invalidation implementing __eq__ and __hash__ sensibly,
+ # which namedtuple does for us (i.e. two _Invalidation objects are the
+ # same if their caches and room_ids match). This is important in
+ # particular to dedupe when we add callbacks to lru cache nodes, otherwise
+ # the number of callbacks would grow.
+ def __call__(self):
+ rules = self.cache.get(self.room_id, None, update_metrics=False)
+ if rules:
+ rules.invalidate_all()
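
The RulesForRoom class above layers three guards: a Linearizer so only one
caller recomputes at a time, a state_group comparison so unchanged state is a
pure cache hit, and a sequence number so that results computed concurrently
with an invalidation are never written back. A self-contained sketch of the
sequence-number guard (illustrative names, not Synapse code):

    class SequencedCache(object):
        def __init__(self):
            self.sequence = 0
            self.data = {}

        def invalidate_all(self):
            # bumping the sequence makes any in-flight computation stale
            self.sequence += 1
            self.data = {}

        def fetch(self, key, compute):
            seq = self.sequence        # snapshot before the slow work
            value = compute(key)       # may race with invalidate_all()
            if seq == self.sequence:   # commit only if nothing invalidated
                self.data[key] = value
            return value

This is the same reasoning behind update_cache() refusing to write unless the
sequence it captured at the start still matches self.sequence.
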
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c7afd11111..a69dda7b09 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,7 +21,6 @@ import logging
from synapse.util.metrics import Measure
from synapse.util.logcontext import LoggingContext
-from mailer import Mailer
logger = logging.getLogger(__name__)
@@ -56,8 +55,10 @@ class EmailPusher(object):
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
- def __init__(self, hs, pusherdict):
+ def __init__(self, hs, pusherdict, mailer):
self.hs = hs
+ self.mailer = mailer
+
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict['id']
@@ -73,16 +74,6 @@ class EmailPusher(object):
self.processing = False
- if self.hs.config.email_enable_notifs:
- if 'data' in pusherdict and 'brand' in pusherdict['data']:
- app_name = pusherdict['data']['brand']
- else:
- app_name = self.hs.config.email_app_name
-
- self.mailer = Mailer(self.hs, app_name)
- else:
- self.mailer = None
-
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index f83aa7625c..b5cd9b426a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
class Mailer(object):
- def __init__(self, hs, app_name):
+ def __init__(self, hs, app_name, notif_template_html, notif_template_text):
self.hs = hs
+ self.notif_template_html = notif_template_html
+ self.notif_template_text = notif_template_text
+
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
- loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
+
logger.info("Created Mailer for app_name %s" % app_name)
- env = jinja2.Environment(loader=loader)
- env.filters["format_ts"] = format_ts_filter
- env.filters["mxc_to_http"] = self.mxc_to_http_filter
- self.notif_template_html = env.get_template(
- self.hs.config.email_notif_template_html
- )
- self.notif_template_text = env.get_template(
- self.hs.config.email_notif_template_text
- )
@defer.inlineCallbacks
def send_notification_mail(self, app_id, user_id, email_address,
@@ -481,28 +475,6 @@ class Mailer(object):
urllib.urlencode(params),
)
- def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
- if value[0:6] != "mxc://":
- return ""
-
- serverAndMediaId = value[6:]
- fragment = None
- if '#' in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
- fragment = "#" + fragment
-
- params = {
- "width": width,
- "height": height,
- "method": resize_method,
- }
- return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- self.hs.config.public_baseurl,
- serverAndMediaId,
- urllib.urlencode(params),
- fragment or "",
- )
-
def safe_markup(raw_html):
return jinja2.Markup(bleach.linkify(bleach.clean(
@@ -543,3 +515,52 @@ def string_ordinal_total(s):
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
+
+
+def load_jinja2_templates(config):
+ """Load the jinja2 email templates from disk
+
+ Returns:
+ (notif_template_html, notif_template_text)
+ """
+ logger.info("loading jinja2")
+
+ loader = jinja2.FileSystemLoader(config.email_template_dir)
+ env = jinja2.Environment(loader=loader)
+ env.filters["format_ts"] = format_ts_filter
+ env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
+
+ notif_template_html = env.get_template(
+ config.email_notif_template_html
+ )
+ notif_template_text = env.get_template(
+ config.email_notif_template_text
+ )
+
+ return notif_template_html, notif_template_text
+
+
+def _create_mxc_to_http_filter(config):
+ def mxc_to_http_filter(value, width, height, resize_method="crop"):
+ if value[0:6] != "mxc://":
+ return ""
+
+ serverAndMediaId = value[6:]
+ fragment = None
+ if '#' in serverAndMediaId:
+ (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+ fragment = "#" + fragment
+
+ params = {
+ "width": width,
+ "height": height,
+ "method": resize_method,
+ }
+ return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+ config.public_baseurl,
+ serverAndMediaId,
+ urllib.urlencode(params),
+ fragment or "",
+ )
+
+ return mxc_to_http_filter
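
The mxc_to_http_filter method only needed self for config access, which is
what tied template loading to a Mailer instance; rewriting it as a closure
over config lets load_jinja2_templates run once at startup. A small sketch of
the closure-as-jinja2-filter pattern (the "shout" filter is made up for
illustration):

    import jinja2

    def _create_shout_filter(suffix):
        # close over configuration instead of binding a method
        def shout_filter(value):
            return value.upper() + suffix
        return shout_filter

    env = jinja2.Environment()
    env.filters["shout"] = _create_shout_filter("!")
    template = env.from_string("{{ 'hello' | shout }}")
    print(template.render())  # HELLO!
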
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index de9c33b936..491f27bded 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
+ from synapse.push.mailer import Mailer, load_jinja2_templates
except:
pass
-def create_pusher(hs, pusherdict):
- logger.info("trying to create_pusher for %r", pusherdict)
+class PusherFactory(object):
+ def __init__(self, hs):
+ self.hs = hs
- PUSHER_TYPES = {
- "http": HttpPusher,
- }
+ self.pusher_types = {
+ "http": HttpPusher,
+ }
- logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
- if hs.config.email_enable_notifs:
- PUSHER_TYPES["email"] = EmailPusher
- logger.info("defined email pusher type")
+ logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
+ if hs.config.email_enable_notifs:
+ self.mailers = {} # app_name -> Mailer
- if pusherdict['kind'] in PUSHER_TYPES:
- logger.info("found pusher")
- return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
+ templates = load_jinja2_templates(hs.config)
+ self.notif_template_html, self.notif_template_text = templates
+
+ self.pusher_types["email"] = self._create_email_pusher
+
+ logger.info("defined email pusher type")
+
+ def create_pusher(self, pusherdict):
+ logger.info("trying to create_pusher for %r", pusherdict)
+
+ if pusherdict['kind'] in self.pusher_types:
+ logger.info("found pusher")
+ return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
+
+ def _create_email_pusher(self, _hs, pusherdict):
+ app_name = self._app_name_from_pusherdict(pusherdict)
+ mailer = self.mailers.get(app_name)
+ if not mailer:
+ mailer = Mailer(
+ hs=self.hs,
+ app_name=app_name,
+ notif_template_html=self.notif_template_html,
+ notif_template_text=self.notif_template_text,
+ )
+ self.mailers[app_name] = mailer
+ return EmailPusher(self.hs, pusherdict, mailer)
+
+ def _app_name_from_pusherdict(self, pusherdict):
+ if 'data' in pusherdict and 'brand' in pusherdict['data']:
+ app_name = pusherdict['data']['brand']
+ else:
+ app_name = self.hs.config.email_app_name
+
+ return app_name
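
PusherFactory builds at most one Mailer per app_name: the Mailer (and the
templates it holds) can be shared by every email pusher with the same
branding, while the per-user EmailPusher stays per-pusher. The memoisation in
_create_email_pusher, reduced to a sketch with stand-in objects:

    class MailerFactory(object):
        def __init__(self):
            self._mailers = {}  # app_name -> shared mailer

        def get_mailer(self, app_name):
            mailer = self._mailers.get(app_name)
            if mailer is None:
                mailer = {"app_name": app_name}  # stand-in for Mailer(...)
                self._mailers[app_name] = mailer
            return mailer

    f = MailerFactory()
    assert f.get_mailer("brand-a") is f.get_mailer("brand-a")
    assert f.get_mailer("brand-a") is not f.get_mailer("brand-b")
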
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3837be523d..43cb6e9c01 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
-import pusher
+from .pusher import PusherFactory
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.async import run_on_reactor
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
+ self.pusher_factory = PusherFactory(_hs)
self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
@@ -48,7 +49,7 @@ class PusherPool:
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
- pusher.create_pusher(self.hs, {
+ self.pusher_factory.create_pusher({
"id": None,
"user_name": user_id,
"kind": kind,
@@ -186,7 +187,7 @@ class PusherPool:
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
try:
- p = pusher.create_pusher(self.hs, pusherdict)
+ p = self.pusher_factory.create_pusher(pusherdict)
except:
logger.exception("Couldn't start a pusher: caught Exception")
continue
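
With create_pusher moved onto PusherFactory, PusherPool now calls a method on
an object it owns instead of a module-level function, which also makes the
dependency swappable in tests. A sketch of that wiring (FakeFactory and Pool
are hypothetical stand-ins):

    class FakeFactory(object):
        def create_pusher(self, pusherdict):
            return ("fake-pusher", pusherdict["kind"])

    class Pool(object):
        def __init__(self, factory):
            # injected, like PusherPool's self.pusher_factory
            self.factory = factory

        def start(self, pusherdicts):
            return [self.factory.create_pusher(d) for d in pusherdicts]

    pool = Pool(FakeFactory())
    print(pool.start([{"kind": "http"}]))  # [('fake-pusher', 'http')]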