diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0d6f48e2d8..c84b23ff46 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -102,6 +102,15 @@ class ObservableDeferred(object):
def observers(self):
return self._observers
+ def has_called(self):
+ return self._result is not None
+
+ def has_succeeded(self):
+ return self._result is not None and self._result[0] is True
+
+ def get_result(self):
+ return self._result[1]
+
def __getattr__(self, name):
return getattr(self._deferred, name)
@@ -185,3 +194,85 @@ class Linearizer(object):
self.key_to_defer.pop(key, None)
defer.returnValue(_ctx_manager())
+
+
+class ReadWriteLock(object):
+ """A deferred style read write lock.
+
+ Example:
+
+ with (yield read_write_lock.read("test_key")):
+ # do some work
+ """
+
+ # IMPLEMENTATION NOTES
+ #
+ # We track the most recent queued reader and writer deferreds (which get
+ # resolved when they release the lock).
+ #
+ # Read: We know its safe to acquire a read lock when the latest writer has
+ # been resolved. The new reader is appeneded to the list of latest readers.
+ #
+ # Write: We know its safe to acquire the write lock when both the latest
+ # writers and readers have been resolved. The new writer replaces the latest
+ # writer.
+
+ def __init__(self):
+ # Latest readers queued
+ self.key_to_current_readers = {}
+
+ # Latest writer queued
+ self.key_to_current_writer = {}
+
+ @defer.inlineCallbacks
+ def read(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.setdefault(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ curr_readers.add(new_defer)
+
+ # We wait for the latest writer to finish writing. We can safely ignore
+ # any existing readers... as they're readers.
+ yield curr_writer
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ self.key_to_current_readers.get(key, set()).discard(new_defer)
+
+ defer.returnValue(_ctx_manager())
+
+ @defer.inlineCallbacks
+ def write(self, key):
+ new_defer = defer.Deferred()
+
+ curr_readers = self.key_to_current_readers.get(key, set())
+ curr_writer = self.key_to_current_writer.get(key, None)
+
+ # We wait on all latest readers and writer.
+ to_wait_on = list(curr_readers)
+ if curr_writer:
+ to_wait_on.append(curr_writer)
+
+ # We can clear the list of current readers since the new writer waits
+ # for them to finish.
+ curr_readers.clear()
+ self.key_to_current_writer[key] = new_defer
+
+ yield defer.gatherResults(to_wait_on)
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ if self.key_to_current_writer[key] == new_defer:
+ self.key_to_current_writer.pop(key)
+
+ defer.returnValue(_ctx_manager())
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index d53569ca49..ebd715c5dc 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -24,11 +24,21 @@ DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
-cache_counter = metrics.register_cache(
- "cache",
- lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
- labels=["name"],
-)
+# cache_counter = metrics.register_cache(
+# "cache",
+# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+# labels=["name"],
+# )
+
+
+def register_cache(name, cache):
+ caches_by_name[name] = cache
+ return metrics.register_cache(
+ "cache",
+ lambda: len(cache),
+ name,
+ )
+
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 758f5982b0..f31dfb22b7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -22,7 +22,7 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
-from . import caches_by_name, DEBUG_CACHES, cache_counter
+from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
@@ -33,6 +33,7 @@ import functools
import inspect
import threading
+
logger = logging.getLogger(__name__)
@@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class Cache(object):
+ __slots__ = (
+ "cache",
+ "max_entries",
+ "name",
+ "keylen",
+ "sequence",
+ "thread",
+ "metrics",
+ )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
@@ -59,7 +69,7 @@ class Cache(object):
self.keylen = keylen
self.sequence = 0
self.thread = None
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -74,10 +84,10 @@ class Cache(object):
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return val
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
@@ -293,16 +303,21 @@ class CacheListDescriptor(object):
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
- cached = {}
+ results = {}
+ cached_defers = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
- res = cache.get(tuple(key)).observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
+ res = cache.get(tuple(key))
+ if not res.has_succeeded():
+ res = res.observe()
+ res.addCallback(lambda r, arg: (arg, r), arg)
+ cached_defers[arg] = res
+ else:
+ results[arg] = res.get_result()
except KeyError:
missing.append(arg)
@@ -340,12 +355,21 @@ class CacheListDescriptor(object):
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
-
- return preserve_context_over_deferred(defer.gatherResults(
- cached.values(),
- consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+ cached_defers[arg] = res
+
+ if cached_defers:
+ def update_results_dict(res):
+ results.update(res)
+ return results
+
+ return preserve_context_over_deferred(defer.gatherResults(
+ cached_defers.values(),
+ consumeErrors=True,
+ ).addCallback(update_results_dict).addErrback(
+ unwrapFirstError
+ ))
+ else:
+ return results
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b0ca1bb79d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,7 +15,7 @@
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
-from . import caches_by_name, cache_counter
+from . import register_cache
import threading
import logging
@@ -43,7 +43,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -58,7 +58,7 @@ class DictionaryCache(object):
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
@@ -69,7 +69,7 @@ class DictionaryCache(object):
if k in entry.value
})
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return DictionaryEntry(False, {})
def invalidate(self, key):
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2b68c1ac93..080388958f 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
import logging
@@ -49,7 +49,7 @@ class ExpiringCache(object):
self._cache = {}
- caches_by_name[cache_name] = self._cache
+ self.metrics = register_cache(cache_name, self._cache)
def start(self):
if not self._expiry_ms:
@@ -78,9 +78,9 @@ class ExpiringCache(object):
def __getitem__(self, key):
try:
entry = self._cache[key]
- cache_counter.inc_hits(self._cache_name)
+ self.metrics.inc_hits()
except KeyError:
- cache_counter.inc_misses(self._cache_name)
+ self.metrics.inc_misses()
raise
if self._reset_expiry_on_get:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 36686b479e..00af539880 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -24,9 +24,12 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
- def __init__(self):
+ def __init__(self, hs, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
+ self.clock = hs.get_clock()
+ self.timeout_sec = timeout_ms / 1000.
+
def get(self, key):
result = self.pending_result_cache.get(key)
if result is not None:
@@ -39,7 +42,13 @@ class ResponseCache(object):
self.pending_result_cache[key] = result
def remove(r):
- self.pending_result_cache.pop(key, None)
+ if self.timeout_sec:
+ self.clock.call_later(
+ self.timeout_sec,
+ self.pending_result_cache.pop, key, None,
+ )
+ else:
+ self.pending_result_cache.pop(key, None)
return r
result.addBoth(remove)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index ea8a74ca69..3c051dabc4 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
from blist import sorteddict
@@ -42,7 +42,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- caches_by_name[self.name] = self._cache
+ self.metrics = register_cache(self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
@@ -53,19 +53,19 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
if stream_pos < latest_entity_change_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
def get_entities_changed(self, entities, stream_pos):
@@ -82,10 +82,10 @@ class StreamChangeCache(object):
self._cache[k] for k in keys[i:]
).intersection(entities)
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
else:
result = entities
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return result
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index d7cccc06b1..e68f94ce77 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -27,10 +27,6 @@ import logging
logger = logging.getLogger(__name__)
-def registered_user(distributor, user):
- return distributor.fire("registered_user", user)
-
-
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 5316259d15..7a87045f87 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -317,7 +317,6 @@ def preserve_fn(f):
def g(*args, **kwargs):
with PreserveLoggingContext(current):
return f(*args, **kwargs)
-
return g
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index e1f374807e..76f301f549 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.logcontext import LoggingContext
import synapse.metrics
+from functools import wraps
import logging
@@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
)
+def measure_func(name):
+ def wrapper(func):
+ @wraps(func)
+ @defer.inlineCallbacks
+ def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, name):
+ r = yield func(self, *args, **kwargs)
+ defer.returnValue(r)
+ return measured_func
+ return wrapper
+
+
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
@@ -64,7 +78,6 @@ class Measure(object):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
if not self.start_context:
- logger.warn("Entered Measure without log context: %s", self.name)
self.start_context = LoggingContext("Measure")
self.start_context.__enter__()
self.created_context = True
@@ -84,8 +97,8 @@ class Measure(object):
if context != self.start_context:
logger.warn(
- "Context have unexpectedly changed from '%s' to '%s'. (%r)",
- context, self.start_context, self.name
+ "Context has unexpectedly changed from '%s' to '%s'. (%r)",
+ self.start_context, context, self.name
)
return
diff --git a/synapse/util/presentable_names.py b/synapse/util/presentable_names.py
index 3efa8a8206..f68676e9e7 100644
--- a/synapse/util/presentable_names.py
+++ b/synapse/util/presentable_names.py
@@ -14,6 +14,9 @@
# limitations under the License.
import re
+import logging
+
+logger = logging.getLogger(__name__)
# intentionally looser than what aliases we allow to be registered since
# other HSes may allow aliases that we would not
@@ -22,7 +25,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room"
-def calculate_room_name(room_state, user_id, fallback_to_members=True):
+def calculate_room_name(room_state, user_id, fallback_to_members=True,
+ fallback_to_single_member=True):
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
@@ -79,7 +83,10 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True):
):
if ("m.room.member", my_member_event.sender) in room_state:
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
- return "Invite from %s" % (name_from_member_event(inviter_member_event),)
+ if fallback_to_single_member:
+ return "Invite from %s" % (name_from_member_event(inviter_member_event),)
+ else:
+ return None
else:
return "Room Invite"
@@ -105,19 +112,29 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True):
# or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id:
if "m.room.third_party_invite" in room_state_bytype:
- third_party_invites = room_state_bytype["m.room.third_party_invite"]
+ third_party_invites = (
+ room_state_bytype["m.room.third_party_invite"].values()
+ )
+
if len(third_party_invites) > 0:
# technically third party invite events are not member
# events, but they are close enough
- return "Inviting %s" (
- descriptor_from_member_events(third_party_invites)
- )
+
+ # FIXME: no they're not - they look nothing like a member;
+ # they have a great big encrypted thing as their name to
+ # prevent leaking the 3PID name...
+ # return "Inviting %s" % (
+ # descriptor_from_member_events(third_party_invites)
+ # )
+ return "Inviting email address"
else:
return ALL_ALONE
else:
return name_from_member_event(all_members[0])
else:
return ALL_ALONE
+ elif len(other_members) == 1 and not fallback_to_single_member:
+ return None
else:
return descriptor_from_member_events(other_members)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 43cf11f3f6..49527f4d21 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -128,7 +128,7 @@ class RetryDestinationLimiter(object):
)
valid_err_code = False
- if exc_type is CodeMessageException:
+ if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = 0 <= exc_val.code < 500
if exc_type is None or valid_err_code:
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4f156cb3b..52086df465 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -21,7 +21,7 @@ import logging
logger = logging.getLogger(__name__)
-def get_version_string(name, module):
+def get_version_string(module):
try:
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -74,11 +74,11 @@ def get_version_string(name, module):
)
return (
- "%s/%s (%s)" % (
- name, module.__version__, git_version,
+ "%s (%s)" % (
+ module.__version__, git_version,
)
).encode("ascii")
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return ("%s/%s" % (name, module.__version__,)).encode("ascii")
+ return module.__version__.encode("ascii")
|