summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/push/emailpusher.py15
-rw-r--r--synapse/push/mailer.py87
-rw-r--r--synapse/push/pusher.py56
-rw-r--r--synapse/push/pusherpool.py7
-rw-r--r--tests/utils.py1
5 files changed, 106 insertions, 60 deletions
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c7afd11111..a69dda7b09 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,7 +21,6 @@ import logging
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import LoggingContext
 
-from mailer import Mailer
 
 logger = logging.getLogger(__name__)
 
@@ -56,8 +55,10 @@ class EmailPusher(object):
     This shares quite a bit of code with httpusher: it would be good to
     factor out the common parts
     """
-    def __init__(self, hs, pusherdict):
+    def __init__(self, hs, pusherdict, mailer):
         self.hs = hs
+        self.mailer = mailer
+
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
         self.pusher_id = pusherdict['id']
@@ -73,16 +74,6 @@ class EmailPusher(object):
 
         self.processing = False
 
-        if self.hs.config.email_enable_notifs:
-            if 'data' in pusherdict and 'brand' in pusherdict['data']:
-                app_name = pusherdict['data']['brand']
-            else:
-                app_name = self.hs.config.email_app_name
-
-            self.mailer = Mailer(self.hs, app_name)
-        else:
-            self.mailer = None
-
     @defer.inlineCallbacks
     def on_started(self):
         if self.mailer is not None:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index f83aa7625c..b5cd9b426a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
 
 
 class Mailer(object):
-    def __init__(self, hs, app_name):
+    def __init__(self, hs, app_name, notif_template_html, notif_template_text):
         self.hs = hs
+        self.notif_template_html = notif_template_html
+        self.notif_template_text = notif_template_text
+
         self.store = self.hs.get_datastore()
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
-        loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
         self.app_name = app_name
+
         logger.info("Created Mailer for app_name %s" % app_name)
-        env = jinja2.Environment(loader=loader)
-        env.filters["format_ts"] = format_ts_filter
-        env.filters["mxc_to_http"] = self.mxc_to_http_filter
-        self.notif_template_html = env.get_template(
-            self.hs.config.email_notif_template_html
-        )
-        self.notif_template_text = env.get_template(
-            self.hs.config.email_notif_template_text
-        )
 
     @defer.inlineCallbacks
     def send_notification_mail(self, app_id, user_id, email_address,
@@ -481,28 +475,6 @@ class Mailer(object):
             urllib.urlencode(params),
         )
 
-    def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        serverAndMediaId = value[6:]
-        fragment = None
-        if '#' in serverAndMediaId:
-            (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
-            fragment = "#" + fragment
-
-        params = {
-            "width": width,
-            "height": height,
-            "method": resize_method,
-        }
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            self.hs.config.public_baseurl,
-            serverAndMediaId,
-            urllib.urlencode(params),
-            fragment or "",
-        )
-
 
 def safe_markup(raw_html):
     return jinja2.Markup(bleach.linkify(bleach.clean(
@@ -543,3 +515,52 @@ def string_ordinal_total(s):
 
 def format_ts_filter(value, format):
     return time.strftime(format, time.localtime(value / 1000))
+
+
+def load_jinja2_templates(config):
+    """Load the jinja2 email templates from disk
+
+    Returns:
+        (notif_template_html, notif_template_text)
+    """
+    logger.info("loading jinja2")
+
+    loader = jinja2.FileSystemLoader(config.email_template_dir)
+    env = jinja2.Environment(loader=loader)
+    env.filters["format_ts"] = format_ts_filter
+    env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
+
+    notif_template_html = env.get_template(
+        config.email_notif_template_html
+    )
+    notif_template_text = env.get_template(
+        config.email_notif_template_text
+    )
+
+    return notif_template_html, notif_template_text
+
+
+def _create_mxc_to_http_filter(config):
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        serverAndMediaId = value[6:]
+        fragment = None
+        if '#' in serverAndMediaId:
+            (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+            fragment = "#" + fragment
+
+        params = {
+            "width": width,
+            "height": height,
+            "method": resize_method,
+        }
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            config.public_baseurl,
+            serverAndMediaId,
+            urllib.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index de9c33b936..9385c80ce3 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
 # process works fine)
 try:
     from synapse.push.emailpusher import EmailPusher
+    from synapse.push.mailer import Mailer, load_jinja2_templates
 except:
     pass
 
 
-def create_pusher(hs, pusherdict):
-    logger.info("trying to create_pusher for %r", pusherdict)
+class PusherFactory(object):
+    def __init__(self, hs):
+        self.hs = hs
 
-    PUSHER_TYPES = {
-        "http": HttpPusher,
-    }
+        self.pusher_types = {
+            "http": HttpPusher,
+        }
 
-    logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
-    if hs.config.email_enable_notifs:
-        PUSHER_TYPES["email"] = EmailPusher
-        logger.info("defined email pusher type")
+        logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
+        if hs.config.email_enable_notifs:
+            self.mailers = {}  # app_name -> Mailer
 
-    if pusherdict['kind'] in PUSHER_TYPES:
-        logger.info("found pusher")
-        return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
+            templates = load_jinja2_templates(hs.config)
+            self.notif_template_html, self.notif_template_text = templates
+
+            self.pusher_types["email"] = self._create_email_pusher
+
+            logger.info("defined email pusher type")
+
+    def create_pusher(self, pusherdict):
+        logger.info("trying to create_pusher for %r", pusherdict)
+
+        if pusherdict['kind'] in self.pusher_types:
+            logger.info("found pusher")
+            return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
+
+    def _create_email_pusher(self, pusherdict):
+        app_name = self._brand_from_pusherdict
+        mailer = self.mailers.get(app_name)
+        if not mailer:
+            mailer = Mailer(
+                hs=self.hs,
+                app_name=app_name,
+                notif_template_html=self.notif_template_html,
+                notif_template_text=self.notif_template_text,
+            )
+            self.mailers[app_name] = mailer
+        return EmailPusher(self.hs, pusherdict, mailer)
+
+    def _app_name_from_pusherdict(self, pusherdict):
+        if 'data' in pusherdict and 'brand' in pusherdict['data']:
+            app_name = pusherdict['data']['brand']
+        else:
+            app_name = self.hs.config.email_app_name
+
+        return app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3837be523d..43cb6e9c01 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -16,7 +16,7 @@
 
 from twisted.internet import defer
 
-import pusher
+from .pusher import PusherFactory
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.util.async import run_on_reactor
 
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
 class PusherPool:
     def __init__(self, _hs):
         self.hs = _hs
+        self.pusher_factory = PusherFactory(_hs)
         self.start_pushers = _hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
@@ -48,7 +49,7 @@ class PusherPool:
         # will then get pulled out of the database,
         # recreated, added and started: this means we have only one
         # code path adding pushers.
-        pusher.create_pusher(self.hs, {
+        self.pusher_factory.create_pusher({
             "id": None,
             "user_name": user_id,
             "kind": kind,
@@ -186,7 +187,7 @@ class PusherPool:
         logger.info("Starting %d pushers", len(pushers))
         for pusherdict in pushers:
             try:
-                p = pusher.create_pusher(self.hs, pusherdict)
+                p = self.pusher_factory.create_pusher(pusherdict)
             except:
                 logger.exception("Couldn't start a pusher: caught Exception")
                 continue
diff --git a/tests/utils.py b/tests/utils.py
index d3d6c8021d..4f7e32b3ab 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -55,6 +55,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.password_providers = []
         config.worker_replication_url = ""
         config.worker_app = None
+        config.email_enable_notifs = False
 
     config.use_frozen_dicts = True
     config.database_config = {"name": "sqlite3"}