diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f506b2a695..7856353002 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -49,7 +49,7 @@ class Clock(object):
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
- defer.returnValue(res)
+ return res
def time(self):
"""Returns the current system time in seconds since epoch."""
@@ -59,7 +59,7 @@ class Clock(object):
"""Returns the current system time in miliseconds since epoch."""
return int(self.time() * 1000)
- def looping_call(self, f, msec):
+ def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
@@ -70,8 +70,10 @@ class Clock(object):
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
+ *args: Postional arguments to pass to function.
+ **kwargs: Key arguments to pass to function.
"""
- call = task.LoopingCall(f)
+ call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 58a6b8764f..f1c46836b1 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -366,7 +366,7 @@ class ReadWriteLock(object):
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
@@ -396,7 +396,7 @@ class ReadWriteLock(object):
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
def _cancelled_to_timed_out_error(value, timeout):
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8271229015..b50e3503f0 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +52,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
-def register_cache(cache_type, cache_name, cache):
+def register_cache(cache_type, cache_name, cache, collect_callback=None):
+ """Register a cache object for metric collection.
+
+ Args:
+ cache_type (str):
+ cache_name (str): name of the cache
+ cache (object): cache itself
+ collect_callback (callable|None): if not None, a function which is called during
+ metric collection to update additional metrics.
+
+ Returns:
+ CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ """
# Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are
@@ -90,6 +103,8 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses)
+ if collect_callback:
+ collect_callback()
except Exception as e:
logger.warn("Error calculating metrics for %s: %s", cache_name, e)
raise
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 675db2f448..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,9 @@ import logging
import threading
from collections import namedtuple
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
from twisted.internet import defer
@@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
from . import register_cache
logger = logging.getLogger(__name__)
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
+
_CacheSentinel = object()
@@ -82,11 +88,19 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
- self.metrics = register_cache("cache", name, self.cache)
+ self.metrics = register_cache(
+ "cache",
+ name,
+ self.cache,
+ collect_callback=self._metrics_collection_callback,
+ )
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
+ def _metrics_collection_callback(self):
+ cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -108,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either a Deferred or the raw result
+ Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -132,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(deferred=value, callbacks=callbacks)
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = defer.maybeDeferred(observable.observe)
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -142,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
- def shuffle(result):
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@@ -163,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
- return result
- entry.deferred.addCallback(shuffle)
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@@ -289,7 +326,7 @@ class CacheDescriptor(_CacheDescriptorBase):
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
- defer.returnValue(r1 + r2)
+ return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
@@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string on py2, try to convert to ascii
- # to save a bit of space in large caches. Py3 does this
- # internally automatically.
- if six.PY2 and isinstance(cache_key, string_types):
- cache_key = to_ascii(cache_key)
-
- result_d = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, result_d, callback=invalidate_callback)
+ result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
- if isinstance(observer, defer.Deferred):
- return make_deferred_yieldable(observer)
- else:
- return observer
+ return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
- # we need an observable deferred for each entry in the list,
+ # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- observable = ObservableDeferred(deferred)
- cache.set(key, observable, callback=invalidate_callback)
+ cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index d6908e169d..82d3eefe0e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -121,7 +121,7 @@ class ResponseCache(object):
@defer.inlineCallbacks
def handle_request(request):
# etc
- defer.returnValue(result)
+ return result
result = yield response_cache.wrap(
key,
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c30b6de19c..0910930c21 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -67,7 +67,7 @@ def measure_func(name):
def measured_func(self, *args, **kwargs):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
- defer.returnValue(r)
+ return r
return measured_func
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d8d0ceae51..0862b5ca5a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -95,15 +95,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# maximum backoff even though it might only have been down briefly
backoff_on_failure = not ignore_backoff
- defer.returnValue(
- RetryDestinationLimiter(
- destination,
- clock,
- store,
- retry_interval,
- backoff_on_failure=backoff_on_failure,
- **kwargs
- )
+ return RetryDestinationLimiter(
+ destination,
+ clock,
+ store,
+ retry_interval,
+ backoff_on_failure=backoff_on_failure,
+ **kwargs
)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 982c6d81ca..6a2464cab3 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +15,15 @@
# limitations under the License.
import random
+import re
import string
import six
from six import PY2, PY3
from six.moves import range
+from synapse.api.errors import Codes, SynapseError
+
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
@@ -27,6 +31,8 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
+client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
+
def random_string(length):
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
@@ -109,3 +115,11 @@ def exception_to_unicode(e):
return msg.decode("utf-8", errors="replace")
else:
return msg
+
+
+def assert_valid_client_secret(client_secret):
+ """Validate that a given string matches the client_secret regex defined by the spec"""
+ if client_secret_regex.match(client_secret) is None:
+ raise SynapseError(
+ 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
+ )
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4d9a462f7..fa404b9d75 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -22,6 +22,23 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
+ """Given a module calculate a git-aware version string for it.
+
+ If called on a module not in a git checkout will return `__verison__`.
+
+ Args:
+ module (module)
+
+ Returns:
+ str
+ """
+
+ cached_version = getattr(module, "_synapse_version_string_cache", None)
+ if cached_version:
+ return cached_version
+
+ version_string = module.__version__
+
try:
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -80,8 +97,10 @@ def get_version_string(module):
s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- return "%s (%s)" % (module.__version__, git_version)
+ version_string = "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__
+ module._synapse_version_string_cache = version_string
+
+ return version_string
|