 synapse/push/bulk_push_rule_evaluator.py | 68
 synapse/push/emailpusher.py              | 15
 synapse/push/mailer.py                   | 87
 synapse/push/pusher.py                   | 56
 synapse/push/pusherpool.py               |  7
 synapse/util/caches/descriptors.py       | 56
 tests/utils.py                           |  1
 7 files changed, 201 insertions(+), 89 deletions(-)
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 760d567ca1..9a96e6fe8f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -53,7 +53,7 @@ class BulkPushRuleEvaluator(object):
         room_id = event.room_id
         rules_for_room = self._get_rules_for_room(room_id)
 
-        rules_by_user = yield rules_for_room.get_rules(context)
+        rules_by_user = yield rules_for_room.get_rules(event, context)
 
         # if this event is an invite event, we may need to run rules for the user
         # who's been invited, otherwise they won't get told they've been invited
@@ -200,6 +200,13 @@ class RulesForRoom(object):
         # not update the cache with it.
         self.sequence = 0
 
+        # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+        # owned by application services, or remote users, etc. (i.e. users we
+        # will never need to calculate push for).
+        # These never need to be invalidated, as we will never set up push for
+        # them.
+        self.uninteresting_user_set = set()
+
         # We need to be clever on the invalidating caches callbacks, as
         # otherwise the invalidation callback holds a reference to the object,
         # potentially causing it to leak.
@@ -209,7 +216,7 @@ class RulesForRoom(object):
         self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
 
     @defer.inlineCallbacks
-    def get_rules(self, context):
+    def get_rules(self, event, context):
         """Given an event context return the rules for all users who are
         currently in the room.
         """
@@ -217,6 +224,7 @@ class RulesForRoom(object):
 
         with (yield self.linearizer.queue(())):
             if state_group and self.state_group == state_group:
+                logger.debug("Using cached rules for %r", self.room_id)
                 defer.returnValue(self.rules_by_user)
 
             ret_rules_by_user = {}
@@ -229,12 +237,30 @@ class RulesForRoom(object):
             else:
                 current_state_ids = context.current_state_ids
 
+            logger.debug(
+                "Looking for member changes in %r %r", state_group, current_state_ids
+            )
+
             # Loop through to see which member events we've seen and have rules
             # for and which we need to fetch
-            for key, event_id in current_state_ids.iteritems():
-                if key[0] != EventTypes.Member:
+            for key in current_state_ids:
+                typ, user_id = key
+                if typ != EventTypes.Member:
+                    continue
+
+                if user_id in self.uninteresting_user_set:
                     continue
 
+                if not self.is_mine_id(user_id):
+                    self.uninteresting_user_set.add(user_id)
+                    continue
+
+                if self.store.get_if_app_services_interested_in_user(user_id):
+                    self.uninteresting_user_set.add(user_id)
+                    continue
+
+                event_id = current_state_ids[key]
+
                 res = self.member_map.get(event_id, None)
                 if res:
                     user_id, state = res
@@ -244,13 +270,6 @@ class RulesForRoom(object):
                             ret_rules_by_user[user_id] = rules
                     continue
 
-                user_id = key[1]
-                if not self.is_mine_id(user_id):
-                    continue
-
-                if self.store.get_if_app_services_interested_in_user(user_id):
-                    continue
-
                 # If a user has left a room we remove their push rule. If they
                 # joined then we readd it later in _update_rules_with_member_event_ids
                 ret_rules_by_user.pop(user_id, None)
@@ -259,15 +278,21 @@ class RulesForRoom(object):
             if missing_member_event_ids:
                 # If we have some member events we haven't seen, look them up
                 # and fetch push rules for them if appropriate.
+                logger.debug("Found new member events %r", missing_member_event_ids)
                 yield self._update_rules_with_member_event_ids(
-                    ret_rules_by_user, missing_member_event_ids, state_group
+                    ret_rules_by_user, missing_member_event_ids, state_group, event
                 )
 
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                "Returning push rules for %r %r",
+                self.room_id, ret_rules_by_user.keys(),
+            )
         defer.returnValue(ret_rules_by_user)
 
     @defer.inlineCallbacks
     def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
-                                            state_group):
+                                            state_group, event):
         """Update the partially filled rules_by_user dict by fetching rules for
         any newly joined users in the `member_event_ids` list.
 
@@ -296,11 +321,23 @@ class RulesForRoom(object):
             for row in rows
         }
 
+        # If the event is a join event then it will be in the current state
+        # events map but not in the DB, so we have to explicitly insert it.
+        if event.type == EventTypes.Member:
+            for event_id in member_event_ids.itervalues():
+                if event_id == event.event_id:
+                    members[event_id] = (event.state_key, event.membership)
+
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug("Found members %r: %r", self.room_id, members.values())
+
         interested_in_user_ids = set(
             user_id for user_id, membership in members.itervalues()
             if membership == Membership.JOIN
         )
 
+        logger.debug("Joined: %r", interested_in_user_ids)
+
         if_users_with_pushers = yield self.store.get_if_users_have_pushers(
             interested_in_user_ids,
             on_invalidate=self.invalidate_all_cb,
@@ -310,10 +347,14 @@ class RulesForRoom(object):
             uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
         )
 
+        logger.debug("With pushers: %r", user_ids)
+
         users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
             self.room_id, on_invalidate=self.invalidate_all_cb,
         )
 
+        logger.debug("With receipts: %r", users_with_receipts)
+
         # any users with pushers must be ours: they have pushers
         for uid in users_with_receipts:
             if uid in interested_in_user_ids:
@@ -334,6 +375,7 @@ class RulesForRoom(object):
         # as it keeps a reference to self and will stop this instance from being
         # GC'd if it gets dropped from the rules_to_user cache. Instead use
         # `self.invalidate_all_cb`
+        logger.debug("Invalidating RulesForRoom for %r", self.room_id)
         self.sequence += 1
         self.state_group = object()
         self.member_map = {}
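
To illustrate the short-circuit that the new uninteresting_user_set gives get_rules(), here is a minimal, self-contained sketch of the same pattern under simplified assumptions: remote users and application-service users are filtered out once, remembered, and skipped on every later pass over the room state. The names below (MemberFilter, is_mine_id, appservice_owns) are hypothetical stand-ins for the homeserver and datastore checks, not the Synapse API.

# Hypothetical sketch of the "uninteresting user" cache pattern; not the real
# RulesForRoom implementation.
def is_mine_id(user_id, server_name="example.org"):
    return user_id.endswith(":" + server_name)

def appservice_owns(user_id):
    # stand-in for store.get_if_app_services_interested_in_user
    return user_id.startswith("@_bridge_")

class MemberFilter(object):
    def __init__(self):
        # users we know we will never push for; never needs invalidating
        self.uninteresting_user_set = set()

    def interesting_members(self, current_state_ids):
        """Yield (user_id, event_id) for members we may need push rules for."""
        for (typ, user_id), event_id in current_state_ids.items():
            if typ != "m.room.member":
                continue
            if user_id in self.uninteresting_user_set:
                continue  # already known to be skippable
            if not is_mine_id(user_id) or appservice_owns(user_id):
                self.uninteresting_user_set.add(user_id)
                continue
            yield user_id, event_id

# usage: only the local, non-appservice member survives the filter
state = {
    ("m.room.member", "@alice:example.org"): "$evt1",
    ("m.room.member", "@bob:remote.net"): "$evt2",
    ("m.room.member", "@_bridge_irc:example.org"): "$evt3",
}
f = MemberFilter()
print(list(f.interesting_members(state)))  # [('@alice:example.org', '$evt1')]
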
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c7afd11111..a69dda7b09 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,7 +21,6 @@ import logging
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import LoggingContext
 
-from mailer import Mailer
 
 logger = logging.getLogger(__name__)
 
@@ -56,8 +55,10 @@ class EmailPusher(object):
     This shares quite a bit of code with httpusher: it would be good to
     factor out the common parts
     """
-    def __init__(self, hs, pusherdict):
+    def __init__(self, hs, pusherdict, mailer):
         self.hs = hs
+        self.mailer = mailer
+
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
         self.pusher_id = pusherdict['id']
@@ -73,16 +74,6 @@ class EmailPusher(object):
 
         self.processing = False
 
-        if self.hs.config.email_enable_notifs:
-            if 'data' in pusherdict and 'brand' in pusherdict['data']:
-                app_name = pusherdict['data']['brand']
-            else:
-                app_name = self.hs.config.email_app_name
-
-            self.mailer = Mailer(self.hs, app_name)
-        else:
-            self.mailer = None
-
     @defer.inlineCallbacks
     def on_started(self):
         if self.mailer is not None:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index f83aa7625c..b5cd9b426a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
 
 
 class Mailer(object):
-    def __init__(self, hs, app_name):
+    def __init__(self, hs, app_name, notif_template_html, notif_template_text):
         self.hs = hs
+        self.notif_template_html = notif_template_html
+        self.notif_template_text = notif_template_text
+
         self.store = self.hs.get_datastore()
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
-        loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
         self.app_name = app_name
+
         logger.info("Created Mailer for app_name %s" % app_name)
-        env = jinja2.Environment(loader=loader)
-        env.filters["format_ts"] = format_ts_filter
-        env.filters["mxc_to_http"] = self.mxc_to_http_filter
-        self.notif_template_html = env.get_template(
-            self.hs.config.email_notif_template_html
-        )
-        self.notif_template_text = env.get_template(
-            self.hs.config.email_notif_template_text
-        )
 
     @defer.inlineCallbacks
     def send_notification_mail(self, app_id, user_id, email_address,
@@ -481,28 +475,6 @@ class Mailer(object):
             urllib.urlencode(params),
         )
 
-    def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        serverAndMediaId = value[6:]
-        fragment = None
-        if '#' in serverAndMediaId:
-            (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
-            fragment = "#" + fragment
-
-        params = {
-            "width": width,
-            "height": height,
-            "method": resize_method,
-        }
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            self.hs.config.public_baseurl,
-            serverAndMediaId,
-            urllib.urlencode(params),
-            fragment or "",
-        )
-
 
 def safe_markup(raw_html):
     return jinja2.Markup(bleach.linkify(bleach.clean(
@@ -543,3 +515,52 @@ def string_ordinal_total(s):
 
 def format_ts_filter(value, format):
     return time.strftime(format, time.localtime(value / 1000))
+
+
+def load_jinja2_templates(config):
+    """Load the jinja2 email templates from disk
+
+    Returns:
+        (notif_template_html, notif_template_text)
+    """
+    logger.info("loading jinja2")
+
+    loader = jinja2.FileSystemLoader(config.email_template_dir)
+    env = jinja2.Environment(loader=loader)
+    env.filters["format_ts"] = format_ts_filter
+    env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
+
+    notif_template_html = env.get_template(
+        config.email_notif_template_html
+    )
+    notif_template_text = env.get_template(
+        config.email_notif_template_text
+    )
+
+    return notif_template_html, notif_template_text
+
+
+def _create_mxc_to_http_filter(config):
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        serverAndMediaId = value[6:]
+        fragment = None
+        if '#' in serverAndMediaId:
+            (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+            fragment = "#" + fragment
+
+        params = {
+            "width": width,
+            "height": height,
+            "method": resize_method,
+        }
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            config.public_baseurl,
+            serverAndMediaId,
+            urllib.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
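
The mailer.py change above moves template loading and the mxc_to_http filter out of Mailer.__init__ so they can be built once and shared. As a rough usage sketch, assuming the patched synapse.push.mailer module is importable and using a stand-in config object that carries only the one attribute the filter reads (the exact query-string ordering depends on urlencode):

from synapse.push.mailer import _create_mxc_to_http_filter

class FakeConfig(object):
    # stand-in for the homeserver config; only public_baseurl is needed here
    public_baseurl = "https://matrix.example.org/"

mxc_to_http = _create_mxc_to_http_filter(FakeConfig())

print(mxc_to_http("mxc://example.org/abc123", 32, 32))
# e.g. https://matrix.example.org/_matrix/media/v1/thumbnail/example.org/abc123?width=32&height=32&method=crop
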
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index de9c33b936..9385c80ce3 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
 # process works fine)
 try:
     from synapse.push.emailpusher import EmailPusher
+    from synapse.push.mailer import Mailer, load_jinja2_templates
 except:
     pass
 
 
-def create_pusher(hs, pusherdict):
-    logger.info("trying to create_pusher for %r", pusherdict)
+class PusherFactory(object):
+    def __init__(self, hs):
+        self.hs = hs
 
-    PUSHER_TYPES = {
-        "http": HttpPusher,
-    }
+        self.pusher_types = {
+            "http": HttpPusher,
+        }
 
-    logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
-    if hs.config.email_enable_notifs:
-        PUSHER_TYPES["email"] = EmailPusher
-        logger.info("defined email pusher type")
+        logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
+        if hs.config.email_enable_notifs:
+            self.mailers = {}  # app_name -> Mailer
 
-    if pusherdict['kind'] in PUSHER_TYPES:
-        logger.info("found pusher")
-        return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
+            templates = load_jinja2_templates(hs.config)
+            self.notif_template_html, self.notif_template_text = templates
+
+            self.pusher_types["email"] = self._create_email_pusher
+
+            logger.info("defined email pusher type")
+
+    def create_pusher(self, pusherdict):
+        logger.info("trying to create_pusher for %r", pusherdict)
+
+        if pusherdict['kind'] in self.pusher_types:
+            logger.info("found pusher")
+            return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
+
+    def _create_email_pusher(self, pusherdict):
+        app_name = self._app_name_from_pusherdict(pusherdict)
+        mailer = self.mailers.get(app_name)
+        if not mailer:
+            mailer = Mailer(
+                hs=self.hs,
+                app_name=app_name,
+                notif_template_html=self.notif_template_html,
+                notif_template_text=self.notif_template_text,
+            )
+            self.mailers[app_name] = mailer
+        return EmailPusher(self.hs, pusherdict, mailer)
+
+    def _app_name_from_pusherdict(self, pusherdict):
+        if 'data' in pusherdict and 'brand' in pusherdict['data']:
+            app_name = pusherdict['data']['brand']
+        else:
+            app_name = self.hs.config.email_app_name
+
+        return app_name
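
The point of PusherFactory is that the jinja2 templates are loaded once at startup and a single Mailer is shared by every email pusher with the same app_name/brand. The following is a simplified sketch of that caching shape only; Mailer, EmailPusher and PusherFactory here are hypothetical stand-ins, not the real Synapse classes.

# Simplified sketch of the "one Mailer per app_name" caching the factory does.
class Mailer(object):
    def __init__(self, app_name, templates):
        self.app_name = app_name
        self.templates = templates

class EmailPusher(object):
    def __init__(self, pusherdict, mailer):
        self.pusherdict = pusherdict
        self.mailer = mailer

class PusherFactory(object):
    def __init__(self, templates, default_app_name="Matrix"):
        self.templates = templates          # loaded once, up front
        self.default_app_name = default_app_name
        self.mailers = {}                   # app_name -> Mailer

    def _app_name(self, pusherdict):
        return pusherdict.get("data", {}).get("brand", self.default_app_name)

    def create_email_pusher(self, pusherdict):
        app_name = self._app_name(pusherdict)
        mailer = self.mailers.get(app_name)
        if mailer is None:
            mailer = Mailer(app_name, self.templates)
            self.mailers[app_name] = mailer
        return EmailPusher(pusherdict, mailer)

# two pushers with the same brand share one Mailer instance
factory = PusherFactory(templates=("html", "text"))
p1 = factory.create_email_pusher({"data": {"brand": "Acme"}})
p2 = factory.create_email_pusher({"data": {"brand": "Acme"}})
assert p1.mailer is p2.mailer
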
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3837be523d..43cb6e9c01 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -16,7 +16,7 @@
 
 from twisted.internet import defer
 
-import pusher
+from .pusher import PusherFactory
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.util.async import run_on_reactor
 
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
 class PusherPool:
     def __init__(self, _hs):
         self.hs = _hs
+        self.pusher_factory = PusherFactory(_hs)
         self.start_pushers = _hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
@@ -48,7 +49,7 @@ class PusherPool:
         # will then get pulled out of the database,
         # recreated, added and started: this means we have only one
         # code path adding pushers.
-        pusher.create_pusher(self.hs, {
+        self.pusher_factory.create_pusher({
             "id": None,
             "user_name": user_id,
             "kind": kind,
@@ -186,7 +187,7 @@ class PusherPool:
         logger.info("Starting %d pushers", len(pushers))
         for pusherdict in pushers:
             try:
-                p = pusher.create_pusher(self.hs, pusherdict)
+                p = self.pusher_factory.create_pusher(pusherdict)
             except:
                 logger.exception("Couldn't start a pusher: caught Exception")
                 continue
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 48dcbafeef..cbdff86596 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -404,6 +404,7 @@ class CacheDescriptor(_CacheDescriptorBase):
 
         wrapped.invalidate_all = cache.invalidate_all
         wrapped.cache = cache
+        wrapped.num_args = self.num_args
 
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -451,8 +452,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
             )
 
     def __get__(self, obj, objtype=None):
-
-        cache = getattr(obj, self.cached_method_name).cache
+        cached_method = getattr(obj, self.cached_method_name)
+        cache = cached_method.cache
+        num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
@@ -469,12 +471,23 @@ class CacheListDescriptor(_CacheDescriptorBase):
             results = {}
             cached_defers = {}
             missing = []
-            for arg in list_args:
+
+            # If the cache takes a single arg then that is used as the key,
+            # otherwise a tuple is used.
+            if num_args == 1:
+                def cache_get(arg):
+                    return cache.get(arg, callback=invalidate_callback)
+            else:
                 key = list(keyargs)
-                key[self.list_pos] = arg
 
+                def cache_get(arg):
+                    key[self.list_pos] = arg
+                    return cache.get(tuple(key), callback=invalidate_callback)
+
+            for arg in list_args:
                 try:
-                    res = cache.get(tuple(key), callback=invalidate_callback)
+                    res = cache_get(arg)
+
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
@@ -505,17 +518,28 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
                     observer = ObservableDeferred(observer)
 
-                    key = list(keyargs)
-                    key[self.list_pos] = arg
-                    cache.set(
-                        tuple(key), observer,
-                        callback=invalidate_callback
-                    )
-
-                    def invalidate(f, key):
-                        cache.invalidate(key)
-                        return f
-                    observer.addErrback(invalidate, tuple(key))
+                    if num_args == 1:
+                        cache.set(
+                            arg, observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, arg)
+                    else:
+                        key = list(keyargs)
+                        key[self.list_pos] = arg
+                        cache.set(
+                            tuple(key), observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, tuple(key))
 
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
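
The descriptors.py change keys batched lookups on the bare argument when the wrapped cache takes a single argument, and on the full argument tuple otherwise. Below is a stripped-down, hypothetical version of that key selection, outside the descriptor machinery, with a plain dict standing in for the LruCache.

# Hypothetical sketch of the key-shape selection the batched lookup now does:
# single-arg caches are keyed on the bare argument, multi-arg caches on the
# argument tuple with the list position substituted per item.
def make_cache_get(cache, keyargs, list_pos, num_args):
    if num_args == 1:
        def cache_get(arg):
            return cache.get(arg)
    else:
        key = list(keyargs)

        def cache_get(arg):
            key[list_pos] = arg
            return cache.get(tuple(key))
    return cache_get

# usage
single_arg_cache = {"@alice:example.org": True}
get1 = make_cache_get(single_arg_cache, ("@alice:example.org",), 0, num_args=1)
print(get1("@alice:example.org"))  # True

multi_arg_cache = {("!room:example.org", "@alice:example.org"): "rules"}
get2 = make_cache_get(multi_arg_cache, ("!room:example.org", None), 1, num_args=2)
print(get2("@alice:example.org"))  # "rules"
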
diff --git a/tests/utils.py b/tests/utils.py
index d3d6c8021d..4f7e32b3ab 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -55,6 +55,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.password_providers = []
         config.worker_replication_url = ""
         config.worker_app = None
+        config.email_enable_notifs = False
 
     config.use_frozen_dicts = True
     config.database_config = {"name": "sqlite3"}