Fix a bug with saml attribute maps.
Fixes a bug where the default attribute maps were prioritised over
user-specified ones, resulting in incorrect mappings.
The problem is that if you call SPConfig.load() multiple times, it adds new
attribute mappers to a list. So by calling it with the default config first,
and then the user-specified config, we would always get the default mappers
before the user-specified mappers.
To solve this, let's merge the config dicts first, and then pass them to
SPConfig.
1 files changed, 28 insertions, 6 deletions
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 6a8161547a..14539fdb2a 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +13,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.module_loader import load_python_module
from ._base import Config, ConfigError
+def _dict_merge(merge_dict, into_dct):
+ for k, v in merge_dict.items():
+ if k not in into_dct:
+ into_dct[k] = v
+ continue
+
+ current_val = into_dct[k]
+
+ if isinstance(v, dict) and isinstance(current_val, dict):
+ _dict_merge(v, current_val)
+ continue
+
+ # otherwise we just overwrite
+ into_dct[k] = v
+
+
class SAML2Config(Config):
def read_config(self, config, **kwargs):
self.saml2_enabled = False
@@ -33,15 +52,18 @@ class SAML2Config(Config):
self.saml2_enabled = True
- import saml2.config
-
- self.saml2_sp_config = saml2.config.SPConfig()
- self.saml2_sp_config.load(self._default_saml_config_dict())
- self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
+ saml2_config_dict = self._default_saml_config_dict()
+ _dict_merge(saml2_config.get("sp_config", {}), saml2_config_dict)
config_path = saml2_config.get("config_path", None)
if config_path is not None:
- self.saml2_sp_config.load_file(config_path)
+ mod = load_python_module(config_path)
+ _dict_merge(mod.CONFIG, saml2_config_dict)
+
+ import saml2.config
+
+ self.saml2_sp_config = saml2.config.SPConfig()
+ self.saml2_sp_config.load(saml2_config_dict)
# session lifetime: in milliseconds
self.saml2_session_lifetime = self.parse_duration(
|