diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 092601f530..0c4504d5d8 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -1,6 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector
-# Copyright 2021 The Matrix.org Foundation C.I.C.
+# Copyright 2021-22 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,11 +20,22 @@ import hashlib
import hmac
import logging
import sys
-from typing import Callable, Optional
+from typing import Any, Callable, Dict, Optional
import requests
import yaml
+_CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
+Conflicting options 'registration_shared_secret' and 'registration_shared_secret_path'
+are both defined in config file.
+"""
+
+_NO_SHARED_SECRET_OPTS_ERROR = """\
+No 'registration_shared_secret' or 'registration_shared_secret_path' defined in config.
+"""
+
+_DEFAULT_SERVER_URL = "http://localhost:8008"
+
def request_registration(
user: str,
@@ -203,31 +214,104 @@ def main() -> None:
parser.add_argument(
"server_url",
- default="https://localhost:8448",
nargs="?",
- help="URL to use to talk to the homeserver. Defaults to "
- " 'https://localhost:8448'.",
+ help="URL to use to talk to the homeserver. By default, tries to find a "
+ "suitable URL from the configuration file. Otherwise, defaults to "
+ f"'{_DEFAULT_SERVER_URL}'.",
)
args = parser.parse_args()
if "config" in args and args.config:
config = yaml.safe_load(args.config)
- secret = config.get("registration_shared_secret", None)
+
+ if args.shared_secret:
+ secret = args.shared_secret
+ else:
+ # argparse should check that we have either config or shared secret
+ assert config
+
+ secret = config.get("registration_shared_secret")
+ secret_file = config.get("registration_shared_secret_path")
+ if secret_file:
+ if secret:
+ print(_CONFLICTING_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
+ sys.exit(1)
+ secret = _read_file(secret_file, "registration_shared_secret_path").strip()
if not secret:
- print("No 'registration_shared_secret' defined in config.")
+ print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
sys.exit(1)
+
+ if args.server_url:
+ server_url = args.server_url
+ elif config:
+ server_url = _find_client_listener(config)
+ if not server_url:
+ server_url = _DEFAULT_SERVER_URL
+ print(
+ "Unable to find a suitable HTTP listener in the configuration file. "
+ f"Trying {server_url} as a last resort.",
+ file=sys.stderr,
+ )
else:
- secret = args.shared_secret
+ server_url = _DEFAULT_SERVER_URL
+ print(
+ f"No server url or configuration file given. Defaulting to {server_url}.",
+ file=sys.stderr,
+ )
admin = None
if args.admin or args.no_admin:
admin = args.admin
register_new_user(
- args.user, args.password, args.server_url, secret, admin, args.user_type
+ args.user, args.password, server_url, secret, admin, args.user_type
)
+def _read_file(file_path: Any, config_path: str) -> str:
+ """Check the given file exists, and read it into a string
+
+ If it does not, exit with an error indicating the problem
+
+ Args:
+ file_path: the file to be read
+ config_path: where in the configuration file_path came from, so that a useful
+ error can be emitted if it does not exist.
+ Returns:
+ content of the file.
+ """
+ if not isinstance(file_path, str):
+ print(f"{config_path} setting is not a string", file=sys.stderr)
+ sys.exit(1)
+
+ try:
+ with open(file_path) as file_stream:
+ return file_stream.read()
+ except OSError as e:
+ print(f"Error accessing file {file_path}: {e}", file=sys.stderr)
+ sys.exit(1)
+
+
+def _find_client_listener(config: Dict[str, Any]) -> Optional[str]:
+ # try to find a listener in the config. Returns a host:port pair
+ for listener in config.get("listeners", []):
+ if listener.get("type") != "http" or listener.get("tls", False):
+ continue
+
+ if not any(
+ name == "client"
+ for resource in listener.get("resources", [])
+ for name in resource.get("names", [])
+ ):
+ continue
+
+ # TODO: consider bind_addresses
+ return f"http://localhost:{listener['port']}"
+
+ # no suitable listeners?
+ return None
+
+
if __name__ == "__main__":
main()
|