summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py142
1 files changed, 136 insertions, 6 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 30d1050a91..1417487427 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,12 +18,16 @@
 import argparse
 import errno
 import os
+import time
+import urllib.parse
 from collections import OrderedDict
+from hashlib import sha256
 from textwrap import dedent
-from typing import Any, MutableMapping, Optional
-
-from six import integer_types
+from typing import Any, Callable, List, MutableMapping, Optional
 
+import attr
+import jinja2
+import pkg_resources
 import yaml
 
 
@@ -100,6 +104,11 @@ class Config(object):
     def __init__(self, root_config=None):
         self.root = root_config
 
+        # Get the path to the default Synapse template directory
+        self.default_template_dir = pkg_resources.resource_filename(
+            "synapse", "res/templates"
+        )
+
     def __getattr__(self, item: str) -> Any:
         """
         Try and fetch a configuration option that does not exist on this class.
@@ -117,7 +126,7 @@ class Config(object):
 
     @staticmethod
     def parse_size(value):
-        if isinstance(value, integer_types):
+        if isinstance(value, int):
             return value
         sizes = {"K": 1024, "M": 1024 * 1024}
         size = 1
@@ -129,7 +138,7 @@ class Config(object):
 
     @staticmethod
     def parse_duration(value):
-        if isinstance(value, integer_types):
+        if isinstance(value, int):
             return value
         second = 1000
         minute = 60 * second
@@ -184,6 +193,95 @@ class Config(object):
         with open(file_path) as file_stream:
             return file_stream.read()
 
+    def read_templates(
+        self, filenames: List[str], custom_template_directory: Optional[str] = None,
+    ) -> List[jinja2.Template]:
+        """Load a list of template files from disk using the given variables.
+
+        This function will attempt to load the given templates from the default Synapse
+        template directory. If `custom_template_directory` is supplied, that directory
+        is tried first.
+
+        Files read are treated as Jinja templates. These templates are not rendered yet.
+
+        Args:
+            filenames: A list of template filenames to read.
+
+            custom_template_directory: A directory to try to look for the templates
+                before using the default Synapse template directory instead.
+
+        Raises:
+            ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+        Returns:
+            A list of jinja2 templates.
+        """
+        templates = []
+        search_directories = [self.default_template_dir]
+
+        # The loader will first look in the custom template directory (if specified) for the
+        # given filename. If it doesn't find it, it will use the default template dir instead
+        if custom_template_directory:
+            # Check that the given template directory exists
+            if not self.path_exists(custom_template_directory):
+                raise ConfigError(
+                    "Configured template directory does not exist: %s"
+                    % (custom_template_directory,)
+                )
+
+            # Search the custom template directory as well
+            search_directories.insert(0, custom_template_directory)
+
+        loader = jinja2.FileSystemLoader(search_directories)
+        env = jinja2.Environment(loader=loader, autoescape=True)
+
+        # Update the environment with our custom filters
+        env.filters.update(
+            {
+                "format_ts": _format_ts_filter,
+                "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+            }
+        )
+
+        for filename in filenames:
+            # Load the template
+            template = env.get_template(filename)
+            templates.append(template)
+
+        return templates
+
+
+def _format_ts_filter(value: int, format: str):
+    return time.strftime(format, time.localtime(value / 1000))
+
+
+def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
+    """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+    Args:
+        public_baseurl: The public, accessible base URL of the homeserver
+    """
+
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        server_and_media_id = value[6:]
+        fragment = None
+        if "#" in server_and_media_id:
+            server_and_media_id, fragment = server_and_media_id.split("#", 1)
+            fragment = "#" + fragment
+
+        params = {"width": width, "height": height, "method": resize_method}
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            public_baseurl,
+            server_and_media_id,
+            urllib.parse.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
+
 
 class RootConfig(object):
     """
@@ -719,4 +817,36 @@ def find_config_files(search_paths):
     return config_files
 
 
-__all__ = ["Config", "RootConfig"]
+@attr.s
+class ShardedWorkerHandlingConfig:
+    """Algorithm for choosing which instance is responsible for handling some
+    sharded work.
+
+    For example, the federation senders use this to determine which instances
+    handles sending stuff to a given destination (which is used as the `key`
+    below).
+    """
+
+    instances = attr.ib(type=List[str])
+
+    def should_handle(self, instance_name: str, key: str) -> bool:
+        """Whether this instance is responsible for handling the given key.
+        """
+
+        # If multiple instances are not defined we always return true.
+        if not self.instances or len(self.instances) == 1:
+            return True
+
+        # We shard by taking the hash, modulo it by the number of instances and
+        # then checking whether this instance matches the instance at that
+        # index.
+        #
+        # (Technically this introduces some bias and is not entirely uniform,
+        # but since the hash is so large the bias is ridiculously small).
+        dest_hash = sha256(key.encode("utf8")).digest()
+        dest_int = int.from_bytes(dest_hash, byteorder="little")
+        remainder = dest_int % (len(self.instances))
+        return self.instances[remainder] == instance_name
+
+
+__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]