summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-09-19 20:29:11 +0100
committerRichard van der Hoff <richard@matrix.org>2019-09-19 20:32:14 +0100
commitb74606ea2262a717193f08bb6876459c1ee2d97d (patch)
tree20a2d0e2bd3a2ff67f7034d91aed8cba0f068601 /synapse
parent1.3.1 (diff)
downloadsynapse-b74606ea2262a717193f08bb6876459c1ee2d97d.tar.xz
Fix a bug with saml attribute maps.
Fixes a bug where the default attribute maps were prioritised over
user-specified ones, resulting in incorrect mappings.

The problem is that if you call SPConfig.load() multiple times, it adds new
attribute mappers to a list. So by calling it with the default config first,
and then the user-specified config, we would always get the default mappers
before the user-specified mappers.

To solve this, let's merge the config dicts first, and then pass them to
SPConfig.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/saml2_config.py34
-rw-r--r--synapse/util/module_loader.py20
2 files changed, 47 insertions, 7 deletions
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 6a8161547a..14539fdb2a 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,11 +13,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.module_loader import load_python_module
 
 from ._base import Config, ConfigError
 
 
+def _dict_merge(merge_dict, into_dct):
+    for k, v in merge_dict.items():
+        if k not in into_dct:
+            into_dct[k] = v
+            continue
+
+        current_val = into_dct[k]
+
+        if isinstance(v, dict) and isinstance(current_val, dict):
+            _dict_merge(v, current_val)
+            continue
+
+        # otherwise we just overwrite
+        into_dct[k] = v
+
+
 class SAML2Config(Config):
     def read_config(self, config, **kwargs):
         self.saml2_enabled = False
@@ -33,15 +52,18 @@ class SAML2Config(Config):
 
         self.saml2_enabled = True
 
-        import saml2.config
-
-        self.saml2_sp_config = saml2.config.SPConfig()
-        self.saml2_sp_config.load(self._default_saml_config_dict())
-        self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
+        saml2_config_dict = self._default_saml_config_dict()
+        _dict_merge(saml2_config.get("sp_config", {}), saml2_config_dict)
 
         config_path = saml2_config.get("config_path", None)
         if config_path is not None:
-            self.saml2_sp_config.load_file(config_path)
+            mod = load_python_module(config_path)
+            _dict_merge(mod.CONFIG, saml2_config_dict)
+
+        import saml2.config
+
+        self.saml2_sp_config = saml2.config.SPConfig()
+        self.saml2_sp_config.load(saml2_config_dict)
 
         # session lifetime: in milliseconds
         self.saml2_session_lifetime = self.parse_duration(
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 522acd5aa8..7ff7eb1e4d 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 
 import importlib
+import importlib.util
 
 from synapse.config._base import ConfigError
 
 
 def load_module(provider):
-    """ Loads a module with its config
+    """ Loads a synapse module with its config
     Take a dict with keys 'module' (the module name) and 'config'
     (the config dict).
 
@@ -38,3 +39,20 @@ def load_module(provider):
         raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
 
     return provider_class, provider_config
+
+
+def load_python_module(location: str):
+    """Load a python module, and return a reference to its global namespace
+
+    Args:
+        location (str): path to the module
+
+    Returns:
+        python module object
+    """
+    spec = importlib.util.spec_from_file_location(location, location)
+    if spec is None:
+        raise Exception("Unable to load module at %s" % (location,))
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod