diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5ac2530a6a..0e8da27f53 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -438,7 +438,7 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
@@ -482,9 +482,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
def __init__(
@@ -618,7 +617,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
)
return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 982c6d81ca..6a2464cab3 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +15,15 @@
# limitations under the License.
import random
+import re
import string
import six
from six import PY2, PY3
from six.moves import range
+from synapse.api.errors import Codes, SynapseError
+
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
@@ -27,6 +31,8 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
+client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
+
def random_string(length):
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
@@ -109,3 +115,11 @@ def exception_to_unicode(e):
return msg.decode("utf-8", errors="replace")
else:
return msg
+
+
+def assert_valid_client_secret(client_secret):
+ """Validate that a given string matches the client_secret regex defined by the spec"""
+ if client_secret_regex.match(client_secret) is None:
+ raise SynapseError(
+ 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
+ )
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
|