diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 804dbca443..7659eaeb42 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -138,7 +138,7 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
+ func (func): Function to execute, should return a deferred or coroutine.
args (list): List of arguments to pass to func, each invocation of func
gets a signle argument.
limit (int): Maximum number of conccurent executions.
@@ -148,11 +148,10 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5ac2530a6a..0e8da27f53 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
@@ -482,9 +482,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
def __init__(
@@ -618,7 +617,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 982c6d81ca..6a2464cab3 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +15,15 @@
# limitations under the License.
import random
+import re
import string
import six
from six import PY2, PY3
from six.moves import range
+from synapse.api.errors import Codes, SynapseError
+
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
@@ -27,6 +31,8 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
+client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
+
def random_string(length):
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
@@ -109,3 +115,11 @@ def exception_to_unicode(e):
return msg.decode("utf-8", errors="replace")
else:
return msg
+
+
+def assert_valid_client_secret(client_secret):
+ """Validate that a given string matches the client_secret regex defined by the spec"""
+ if client_secret_regex.match(client_secret) is None:
+ raise SynapseError(
+ 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
+ )
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index fa404b9d75..ab7d03af3a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -42,6 +42,7 @@ def get_version_string(module):
try:
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
+
try:
git_branch = (
subprocess.check_output(
@@ -51,7 +52,8 @@ def get_version_string(module):
.decode("ascii")
)
git_branch = "b=" + git_branch
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # FileNotFoundError can arise when git is not installed
git_branch = ""
try:
@@ -63,7 +65,7 @@ def get_version_string(module):
.decode("ascii")
)
git_tag = "t=" + git_tag
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_tag = ""
try:
@@ -74,7 +76,7 @@ def get_version_string(module):
.strip()
.decode("ascii")
)
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_commit = ""
try:
@@ -89,7 +91,7 @@ def get_version_string(module):
)
git_dirty = "dirty" if is_dirty else ""
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
|