diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 3b9da5b34a..2b3f0bef3c 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -49,9 +50,6 @@ class Clock(object):
l.start(msec / 1000.0, now=False)
return l
- def stop_looping_call(self, loop):
- loop.stop()
-
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
@@ -83,7 +81,7 @@ class Clock(object):
def timed_out_fn():
try:
- ret_deferred.errback(RuntimeError("Timed out"))
+ ret_deferred.errback(SynapseError(504, "Timed out"))
except:
pass
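With this change a timeout from `Clock.time_bound_deferred` surfaces as a `SynapseError` carrying an HTTP status instead of a bare `RuntimeError`, so callers can branch on the 504 code. A minimal caller sketch (the `time_out` parameter name is assumed; `fallback` is hypothetical):

```python
from twisted.internet import defer
from synapse.api.errors import SynapseError

@defer.inlineCallbacks
def fetch_with_timeout(clock, deferred, fallback):
    try:
        result = yield clock.time_bound_deferred(deferred, time_out=10)
    except SynapseError as e:
        if e.code != 504:
            raise
        result = fallback  # timed out (the errback added above)
    defer.returnValue(result)
```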
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 640fae3890..40be7fe7e3 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,7 +16,12 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext
+from .logcontext import (
+ PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+)
+from synapse.util import unwrapFirstError
+
+from contextlib import contextmanager
@defer.inlineCallbacks
@@ -97,6 +102,15 @@ class ObservableDeferred(object):
def observers(self):
return self._observers
+ def has_called(self):
+ return self._result is not None
+
+ def has_succeeded(self):
+ return self._result is not None and self._result[0] is True
+
+ def get_result(self):
+ return self._result[1]
+
def __getattr__(self, name):
return getattr(self._deferred, name)
@@ -107,3 +121,76 @@ class ObservableDeferred(object):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)
+
+
+def concurrently_execute(func, args, limit):
+ """Executes the function with each argument conncurrently while limiting
+ the number of concurrent executions.
+
+ Args:
+ func (func): Function to execute, should return a deferred.
+ args (list): List of arguments to pass to func, each invocation of func
+ gets a single argument.
+ limit (int): Maximum number of concurrent executions.
+
+ Returns:
+ deferred: Resolved when all function invocations have finished.
+ """
+ it = iter(args)
+
+ @defer.inlineCallbacks
+ def _concurrently_execute_inner():
+ try:
+ while True:
+ yield func(it.next())
+ except StopIteration:
+ pass
+
+ return defer.gatherResults([
+ preserve_fn(_concurrently_execute_inner)()
+ for _ in xrange(limit)
+ ], consumeErrors=True).addErrback(unwrapFirstError)
+
+
+class Linearizer(object):
+ """Linearizes access to resources based on a key. Useful to ensure only one
+ thing is happening at a time on a given resource.
+
+ Example:
+
+ with (yield linearizer.queue("test_key")):
+ # do some work.
+
+ """
+ def __init__(self):
+ self.key_to_defer = {}
+
+ @defer.inlineCallbacks
+ def queue(self, key):
+ # If there is already a deferred in the queue, we pull it out so that
+ # we can wait on it later.
+ # Then we replace it with a deferred that we resolve *after* the
+ # context manager has exited.
+ # We only return the context manager after the previous deferred has
+ # resolved.
+ # This all has the net effect of creating a chain of deferreds that
+ # wait for the previous deferred before starting their work.
+ current_defer = self.key_to_defer.get(key)
+
+ new_defer = defer.Deferred()
+ self.key_to_defer[key] = new_defer
+
+ if current_defer:
+ yield preserve_context_over_deferred(current_defer)
+
+ @contextmanager
+ def _ctx_manager():
+ try:
+ yield
+ finally:
+ new_defer.callback(None)
+ current_d = self.key_to_defer.get(key)
+ if current_d is new_defer:
+ self.key_to_defer.pop(key, None)
+
+ defer.returnValue(_ctx_manager())
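A usage sketch covering both new helpers (store and method names are hypothetical): `concurrently_execute` spawns `limit` workers that pull arguments off a shared iterator, while `Linearizer` serialises critical sections per key.

```python
from twisted.internet import defer

@defer.inlineCallbacks
def process_rooms(store, linearizer, room_ids):
    @defer.inlineCallbacks
    def handle_room(room_id):
        # Only one handler may work on a given room at a time; the
        # context manager resolves the next waiter's deferred on exit.
        with (yield linearizer.queue(room_id)):
            yield store.update_room(room_id)  # hypothetical

    # At most 10 rooms are processed at once; the first failure
    # propagates via unwrapFirstError as in the implementation above.
    yield concurrently_execute(handle_room, room_ids, limit=10)
```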
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index d53569ca49..ebd715c5dc 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -24,11 +24,21 @@ DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
-cache_counter = metrics.register_cache(
- "cache",
- lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
- labels=["name"],
-)
+# cache_counter = metrics.register_cache(
+# "cache",
+# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+# labels=["name"],
+# )
+
+
+def register_cache(name, cache):
+ caches_by_name[name] = cache
+ return metrics.register_cache(
+ "cache",
+ lambda: len(cache),
+ name,
+ )
+
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 35544b19fd..f31dfb22b7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -22,7 +22,7 @@ from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
-from . import caches_by_name, DEBUG_CACHES, cache_counter
+from . import DEBUG_CACHES, register_cache
from twisted.internet import defer
@@ -33,6 +33,7 @@ import functools
import inspect
import threading
+
logger = logging.getLogger(__name__)
@@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class Cache(object):
+ __slots__ = (
+ "cache",
+ "max_entries",
+ "name",
+ "keylen",
+ "sequence",
+ "thread",
+ "metrics",
+ )
def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
@@ -59,7 +69,7 @@ class Cache(object):
self.keylen = keylen
self.sequence = 0
self.thread = None
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -74,10 +84,10 @@ class Cache(object):
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return val
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
@@ -167,7 +177,8 @@ class CacheDescriptor(object):
% (orig.__name__,)
)
- self.cache = Cache(
+ def __get__(self, obj, objtype=None):
+ cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
@@ -175,14 +186,12 @@ class CacheDescriptor(object):
tree=self.tree,
)
- def __get__(self, obj, objtype=None):
-
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result_d = self.cache.get(cache_key)
+ cached_result_d = cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@@ -204,7 +213,7 @@ class CacheDescriptor(object):
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = self.cache.sequence
+ sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
@@ -213,20 +222,21 @@ class CacheDescriptor(object):
)
def onErr(f):
- self.cache.invalidate(cache_key)
+ cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- self.cache.update(sequence, cache_key, ret)
+ cache.update(sequence, cache_key, ret)
return preserve_context_over_deferred(ret.observe())
- wrapped.invalidate = self.cache.invalidate
- wrapped.invalidate_all = self.cache.invalidate_all
- wrapped.invalidate_many = self.cache.invalidate_many
- wrapped.prefill = self.cache.prefill
+ wrapped.invalidate = cache.invalidate
+ wrapped.invalidate_all = cache.invalidate_all
+ wrapped.invalidate_many = cache.invalidate_many
+ wrapped.prefill = cache.prefill
+ wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped
@@ -240,11 +250,12 @@ class CacheListDescriptor(object):
the list of missing keys to the wrapped function.
"""
- def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+ def __init__(self, orig, cached_method_name, list_name, num_args=1,
+ inlineCallbacks=False):
"""
Args:
orig (function)
- cache (Cache)
+ cached_method_name (str): The name of the cached method.
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
@@ -263,7 +274,7 @@ class CacheListDescriptor(object):
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
- self.cache = cache
+ self.cached_method_name = cached_method_name
self.sentinel = object()
@@ -277,11 +288,13 @@ class CacheListDescriptor(object):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
- % (self.list_name, cache.name,)
+ % (self.list_name, cached_method_name,)
)
def __get__(self, obj, objtype=None):
+ cache = getattr(obj, self.cached_method_name).cache
+
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@@ -290,21 +303,26 @@ class CacheListDescriptor(object):
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
- cached = {}
+ results = {}
+ cached_defers = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
- res = self.cache.get(tuple(key)).observe()
- res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
+ res = cache.get(tuple(key))
+ if not res.has_succeeded():
+ res = res.observe()
+ res.addCallback(lambda r, arg: (arg, r), arg)
+ cached_defers[arg] = res
+ else:
+ results[arg] = res.get_result()
except KeyError:
missing.append(arg)
if missing:
- sequence = self.cache.sequence
+ sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@@ -327,22 +345,31 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- self.cache.update(sequence, tuple(key), observer)
+ cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
- self.cache.invalidate(key)
+ cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
- cached[arg] = res
+ cached_defers[arg] = res
+
+ if cached_defers:
+ def update_results_dict(res):
+ results.update(res)
+ return results
- return preserve_context_over_deferred(defer.gatherResults(
- cached.values(),
- consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+ return preserve_context_over_deferred(defer.gatherResults(
+ cached_defers.values(),
+ consumeErrors=True,
+ ).addCallback(update_results_dict).addErrback(
+ unwrapFirstError
+ ))
+ else:
+ return results
obj.__dict__[self.orig.__name__] = wrapped
@@ -370,7 +397,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
)
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -400,7 +427,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
"""
return lambda orig: CacheListDescriptor(
orig,
- cache=cache,
+ cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
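With the `Cache` construction moved into `__get__`, each instance gets its own cache, and `@cachedList` now refers to the cached method by name instead of holding a `Cache` object. A sketch of the resulting decorator usage (the class and db calls are hypothetical):

```python
class RoomStore(object):
    @cached()
    def get_room(self, room_id):
        return self.db.fetch_room(room_id)  # hypothetical; returns a deferred

    @cachedList(cached_method_name="get_room", list_name="room_ids")
    def get_rooms(self, room_ids):
        # Invoked only with the keys that missed get_room's cache;
        # should return a deferred dict of room_id -> room.
        return self.db.fetch_rooms(room_ids)  # hypothetical
```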
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b0ca1bb79d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,7 +15,7 @@
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
-from . import caches_by_name, cache_counter
+from . import register_cache
import threading
import logging
@@ -43,7 +43,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- caches_by_name[name] = self.cache
+ self.metrics = register_cache(name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -58,7 +58,7 @@ class DictionaryCache(object):
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
@@ -69,7 +69,7 @@ class DictionaryCache(object):
if k in entry.value
})
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return DictionaryEntry(False, {})
def invalidate(self, key):
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2b68c1ac93..080388958f 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
import logging
@@ -49,7 +49,7 @@ class ExpiringCache(object):
self._cache = {}
- caches_by_name[cache_name] = self._cache
+ self.metrics = register_cache(cache_name, self._cache)
def start(self):
if not self._expiry_ms:
@@ -78,9 +78,9 @@ class ExpiringCache(object):
def __getitem__(self, key):
try:
entry = self._cache[key]
- cache_counter.inc_hits(self._cache_name)
+ self.metrics.inc_hits()
except KeyError:
- cache_counter.inc_misses(self._cache_name)
+ self.metrics.inc_misses()
raise
if self._reset_expiry_on_get:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
new file mode 100644
index 0000000000..36686b479e
--- /dev/null
+++ b/synapse/util/caches/response_cache.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.async import ObservableDeferred
+
+
+class ResponseCache(object):
+ """
+ This caches a deferred response. Until the deferred completes it will be
+ returned from the cache. This means that if the client retries the request
+ while the response is still being computed, that original response will be
+ used rather than trying to compute a new response.
+ """
+
+ def __init__(self):
+ self.pending_result_cache = {} # Requests that haven't finished yet.
+
+ def get(self, key):
+ result = self.pending_result_cache.get(key)
+ if result is not None:
+ return result.observe()
+ else:
+ return None
+
+ def set(self, key, deferred):
+ result = ObservableDeferred(deferred, consumeErrors=True)
+ self.pending_result_cache[key] = result
+
+ def remove(r):
+ self.pending_result_cache.pop(key, None)
+ return r
+
+ result.addBoth(remove)
+ return result.observe()
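A usage sketch of the intended get-then-set pattern (handler names hypothetical): concurrent retries of the same request observe the single in-flight deferred instead of recomputing.

```python
from twisted.internet import defer

@defer.inlineCallbacks
def on_request(self, request_key):
    result = self.response_cache.get(request_key)
    if result is None:
        # Not already being computed: start the work and cache the
        # ObservableDeferred until it resolves.
        result = self.response_cache.set(
            request_key, self._compute_response(request_key)  # hypothetical
        )
    response = yield result
    defer.returnValue(response)
```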
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index ea8a74ca69..3c051dabc4 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
from blist import sorteddict
@@ -42,7 +42,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- caches_by_name[self.name] = self._cache
+ self.metrics = register_cache(self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
@@ -53,19 +53,19 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos < self._earliest_known_stream_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
latest_entity_change_pos = self._entity_to_key.get(entity, None)
if latest_entity_change_pos is None:
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
if stream_pos < latest_entity_change_pos:
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return True
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
return False
def get_entities_changed(self, entities, stream_pos):
@@ -82,10 +82,10 @@ class StreamChangeCache(object):
self._cache[k] for k in keys[i:]
).intersection(entities)
- cache_counter.inc_hits(self.name)
+ self.metrics.inc_hits()
else:
result = entities
- cache_counter.inc_misses(self.name)
+ self.metrics.inc_misses()
return result
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 8875813de4..d7cccc06b1 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,7 +15,9 @@
from twisted.internet import defer
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_fn
+)
from synapse.util import unwrapFirstError
@@ -25,6 +27,24 @@ import logging
logger = logging.getLogger(__name__)
+def registered_user(distributor, user):
+ return distributor.fire("registered_user", user)
+
+
+def user_left_room(distributor, user, room_id):
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_left_room", user=user, room_id=room_id
+ )
+
+
+def user_joined_room(distributor, user, room_id):
+ return preserve_context_over_fn(
+ distributor.fire,
+ "user_joined_room", user=user, room_id=room_id
+ )
+
+
class Distributor(object):
"""A central dispatch point for loosely-connected pieces of code to
register, observe, and fire signals.
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
new file mode 100644
index 0000000000..45be47159a
--- /dev/null
+++ b/synapse/util/httpresourcetree.py
@@ -0,0 +1,98 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def create_resource_tree(desired_tree, root_resource):
+ """Create the resource tree for this Home Server.
+
+ This is unduly complicated because Twisted does not support putting
+ child resources more than 1 level deep at a time.
+
+ Args:
+ desired_tree (dict): Dict from full path to the Resource to attach there.
+ root_resource (twisted.web.resource.Resource): The root
+ resource to add the tree to.
+ Returns:
+ twisted.web.resource.Resource: the ``root_resource`` with a tree of
+ child resources added to it.
+ """
+
+ # ideally we'd just use getChild and putChild but getChild doesn't work
+ # unless you give it a Request object IN ADDITION to the name :/ So
+ # instead, we'll store a copy of this mapping so we can actually add
+ # extra resources to existing nodes. See _resource_id for the key.
+ resource_mappings = {}
+ for full_path, res in desired_tree.items():
+ logger.info("Attaching %s to path %s", res, full_path)
+ last_resource = root_resource
+ for path_seg in full_path.split('/')[1:-1]:
+ if path_seg not in last_resource.listNames():
+ # resource doesn't exist, so make a "dummy resource"
+ child_resource = Resource()
+ last_resource.putChild(path_seg, child_resource)
+ res_id = _resource_id(last_resource, path_seg)
+ resource_mappings[res_id] = child_resource
+ last_resource = child_resource
+ else:
+ # we have an existing Resource, use that instead.
+ res_id = _resource_id(last_resource, path_seg)
+ last_resource = resource_mappings[res_id]
+
+ # ===========================
+ # now attach the actual desired resource
+ last_path_seg = full_path.split('/')[-1]
+
+ # if there is already a resource here, thieve its children and
+ # replace it
+ res_id = _resource_id(last_resource, last_path_seg)
+ if res_id in resource_mappings:
+ # there is a dummy resource at this path already, which needs
+ # to be replaced with the desired resource.
+ existing_dummy_resource = resource_mappings[res_id]
+ for child_name in existing_dummy_resource.listNames():
+ child_res_id = _resource_id(
+ existing_dummy_resource, child_name
+ )
+ child_resource = resource_mappings[child_res_id]
+ # steal the children
+ res.putChild(child_name, child_resource)
+
+ # finally, insert the desired resource in the right place
+ last_resource.putChild(last_path_seg, res)
+ res_id = _resource_id(last_resource, last_path_seg)
+ resource_mappings[res_id] = res
+
+ return root_resource
+
+
+def _resource_id(resource, path_seg):
+ """Construct an arbitrary resource ID so you can retrieve the mapping
+ later.
+
+ If you want to represent resource A putChild resource B with path C,
+ the mapping should look like _resource_id(A,C) = B.
+
+ Args:
+ resource (Resource): The *parent* Resource.
+ path_seg (str): The name of the child Resource to be attached.
+ Returns:
+ str: A unique string which can be a key to the child Resource.
+ """
+ return "%s-%s" % (resource, path_seg)
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
new file mode 100644
index 0000000000..97e0f00b67
--- /dev/null
+++ b/synapse/util/manhole.py
@@ -0,0 +1,70 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.conch.manhole import ColoredManhole
+from twisted.conch.insults import insults
+from twisted.conch import manhole_ssh
+from twisted.cred import checkers, portal
+from twisted.conch.ssh.keys import Key
+
+PUBLIC_KEY = (
+ "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az"
+ "64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJS"
+ "kbh/C+BR3utDS555mV"
+)
+
+PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY-----
+MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
+4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
+vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
+Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
+xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
+PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
+gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
+DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
+pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
+EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
+-----END RSA PRIVATE KEY-----"""
+
+
+def manhole(username, password, globals):
+ """Starts a ssh listener with password authentication using
+ the given username and password. Clients connecting to the ssh
+ listener will find themselves in a colored python shell with
+ the supplied globals.
+
+ Args:
+ username(str): The username ssh clients should auth with.
+ password(str): The password ssh clients should auth with.
+ globals(dict): The variables to expose in the shell.
+
+ Returns:
+ twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
+ """
+
+ checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
+ **{username: password}
+ )
+
+ rlm = manhole_ssh.TerminalRealm()
+ rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
+ ColoredManhole,
+ dict(globals, __name__="__console__")
+ )
+
+ factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
+ factory.publicKeys['ssh-rsa'] = Key.fromString(PUBLIC_KEY)
+ factory.privateKeys['ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+
+ return factory
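A usage sketch (port and credentials hypothetical): the returned factory plugs straight into `listenTCP`, as the docstring above says.

```python
from twisted.internet import reactor

factory = manhole(
    username="matrix",
    password="rabbithole",
    globals={"hs": hs},  # hypothetical homeserver object to expose
)
reactor.listenTCP(9000, factory, interface="127.0.0.1")
```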
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c51b641125..e1f374807e 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution(
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
- "ru_stime", "db_txn_count", "db_txn_duration"
+ "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
]
def __init__(self, clock, name):
@@ -58,14 +58,20 @@ class Measure(object):
self.name = name
self.start_context = None
self.start = None
+ self.created_context = False
def __enter__(self):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
- if self.start_context:
- self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
- self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration = self.start_context.db_txn_duration
+ if not self.start_context:
+ logger.warn("Entered Measure without log context: %s", self.name)
+ self.start_context = LoggingContext("Measure")
+ self.start_context.__enter__()
+ self.created_context = True
+
+ self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
+ self.db_txn_count = self.start_context.db_txn_count
+ self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None or not self.start_context:
@@ -91,7 +97,12 @@ class Measure(object):
block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
- block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
+ block_db_txn_count.inc_by(
+ context.db_txn_count - self.db_txn_count, self.name
+ )
block_db_txn_duration.inc_by(
context.db_txn_duration - self.db_txn_duration, self.name
)
+
+ if self.created_context:
+ self.start_context.__exit__(exc_type, exc_val, exc_tb)
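With this change `Measure` no longer silently skips accounting when entered outside a `LoggingContext`: it warns, creates a throwaway "Measure" context, and tears it down on exit. Usage is unchanged (the work function is hypothetical):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def do_expensive_thing(clock):
    with Measure(clock, "expensive_thing"):
        # ru_utime/ru_stime and db txn stats are now recorded even if
        # no log context was active when we entered.
        yield run_queries()  # hypothetical
```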
diff --git a/synapse/util/presentable_names.py b/synapse/util/presentable_names.py
new file mode 100644
index 0000000000..a6866f6117
--- /dev/null
+++ b/synapse/util/presentable_names.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import logging
+
+logger = logging.getLogger(__name__)
+
+# intentionally looser than what aliases we allow to be registered since
+# other HSes may allow aliases that we would not
+ALIAS_RE = re.compile(r"^#.*:.+$")
+
+ALL_ALONE = "Empty Room"
+
+
+def calculate_room_name(room_state, user_id, fallback_to_members=True):
+ """
+ Works out a user-facing name for the given room as per Matrix
+ spec recommendations.
+ Does not yet support internationalisation.
+ Args:
+ room_state: Dictionary of the room's state
+ user_id: The ID of the user to whom the room name is being presented
+ fallback_to_members: If False, return None instead of generating a name
+ based on the room's members if the room has no
+ title or aliases.
+
+ Returns:
+ (string or None) A human readable name for the room.
+ """
+ # does it have a name?
+ if ("m.room.name", "") in room_state:
+ m_room_name = room_state[("m.room.name", "")]
+ if m_room_name.content and m_room_name.content["name"]:
+ return m_room_name.content["name"]
+
+ # does it have a canonical alias?
+ if ("m.room.canonical_alias", "") in room_state:
+ canon_alias = room_state[("m.room.canonical_alias", "")]
+ if (
+ canon_alias.content and canon_alias.content["alias"] and
+ _looks_like_an_alias(canon_alias.content["alias"])
+ ):
+ return canon_alias.content["alias"]
+
+ # at this point we're going to need to search the state by all state keys
+ # for an event type, so rearrange the data structure
+ room_state_bytype = _state_as_two_level_dict(room_state)
+
+ # right then, any aliases at all?
+ if "m.room.aliases" in room_state_bytype:
+ m_room_aliases = room_state_bytype["m.room.aliases"]
+ if len(m_room_aliases.values()) > 0:
+ first_alias_event = m_room_aliases.values()[0]
+ if first_alias_event.content and first_alias_event.content["aliases"]:
+ the_aliases = first_alias_event.content["aliases"]
+ if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
+ return the_aliases[0]
+
+ if not fallback_to_members:
+ return None
+
+ my_member_event = None
+ if ("m.room.member", user_id) in room_state:
+ my_member_event = room_state[("m.room.member", user_id)]
+
+ if (
+ my_member_event is not None and
+ my_member_event.content['membership'] == "invite"
+ ):
+ if ("m.room.member", my_member_event.sender) in room_state:
+ inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
+ return "Invite from %s" % (name_from_member_event(inviter_member_event),)
+ else:
+ return "Room Invite"
+
+ # we're going to have to generate a name based on who's in the room,
+ # so find out who is in the room that isn't the user.
+ if "m.room.member" in room_state_bytype:
+ all_members = [
+ ev for ev in room_state_bytype["m.room.member"].values()
+ if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
+ ]
+ # Sort the member events oldest-first so that we name people in the
+ # order they joined (it should at least be deterministic rather than
+ # dictionary iteration order)
+ all_members.sort(key=lambda e: e.origin_server_ts)
+ other_members = [m for m in all_members if m.state_key != user_id]
+ else:
+ other_members = []
+ all_members = []
+
+ if len(other_members) == 0:
+ if len(all_members) == 1:
+ # self-chat, peeked room with 1 participant,
+ # or inbound invite, or outbound 3PID invite.
+ if all_members[0].sender == user_id:
+ if "m.room.third_party_invite" in room_state_bytype:
+ third_party_invites = (
+ room_state_bytype["m.room.third_party_invite"].values()
+ )
+
+ if len(third_party_invites) > 0:
+ # technically third party invite events are not member
+ # events, but they are close enough
+
+ # FIXME: no they're not - they look nothing like a member;
+ # they have a great big encrypted thing as their name to
+ # prevent leaking the 3PID name...
+ # return "Inviting %s" % (
+ # descriptor_from_member_events(third_party_invites)
+ # )
+ return "Inviting email address"
+ else:
+ return ALL_ALONE
+ else:
+ return name_from_member_event(all_members[0])
+ else:
+ return ALL_ALONE
+ else:
+ return descriptor_from_member_events(other_members)
+
+
+def descriptor_from_member_events(member_events):
+ if len(member_events) == 0:
+ return "nobody"
+ elif len(member_events) == 1:
+ return name_from_member_event(member_events[0])
+ elif len(member_events) == 2:
+ return "%s and %s" % (
+ name_from_member_event(member_events[0]),
+ name_from_member_event(member_events[1]),
+ )
+ else:
+ return "%s and %d others" % (
+ name_from_member_event(member_events[0]),
+ len(member_events) - 1,
+ )
+
+
+def name_from_member_event(member_event):
+ if (
+ member_event.content and "displayname" in member_event.content and
+ member_event.content["displayname"]
+ ):
+ return member_event.content["displayname"]
+ return member_event.state_key
+
+
+def _state_as_two_level_dict(state):
+ ret = {}
+ for k, v in state.items():
+ ret.setdefault(k[0], {})[k[1]] = v
+ return ret
+
+
+def _looks_like_an_alias(string):
+ return ALIAS_RE.match(string) is not None
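A sketch of the expected input shape (the event objects are hypothetical, with the `.content`, `.sender`, `.state_key` and `.origin_server_ts` attributes the code above reads): `room_state` maps `(event_type, state_key)` pairs to events, which is why the lookups use two-tuples and `_state_as_two_level_dict` rearranges them.

```python
room_state = {
    ("m.room.name", ""): name_event,                       # hypothetical events
    ("m.room.member", "@alice:example.com"): alice_event,
    ("m.room.member", "@bob:example.com"): bob_event,
}
name = calculate_room_name(room_state, "@alice:example.com")
# Falls back through name -> canonical alias -> any alias -> members.
```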
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 4076eed269..1101881a2d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -100,20 +100,6 @@ class _PerHostRatelimiter(object):
self.current_processing = set()
self.request_times = []
- def is_empty(self):
- time_now = self.clock.time_msec()
- self.request_times[:] = [
- r for r in self.request_times
- if time_now - r < self.window_size
- ]
-
- return not (
- self.ready_request_queue
- or self.sleeping_requests
- or self.current_processing
- or self.request_times
- )
-
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
new file mode 100644
index 0000000000..f4a9abf83f
--- /dev/null
+++ b/synapse/util/rlimit.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import resource
+import logging
+
+
+logger = logging.getLogger("synapse.app.homeserver")
+
+
+def change_resource_limit(soft_file_no):
+ try:
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+
+ if not soft_file_no:
+ soft_file_no = hard
+
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
+ logger.info("Set file limit to: %d", soft_file_no)
+
+ resource.setrlimit(
+ resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
+ )
+ except (ValueError, resource.error) as e:
+ logger.warn("Failed to set file or core limit: %s", e)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index b490bb8725..a100f151d4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -21,10 +21,6 @@ _string_with_symbols = (
)
-def origin_from_ucid(ucid):
- return ucid.split("@", 1)[1]
-
-
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
new file mode 100644
index 0000000000..a4f156cb3b
--- /dev/null
+++ b/synapse/util/versionstring.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def get_version_string(name, module):
+ try:
+ null = open(os.devnull, 'w')
+ cwd = os.path.dirname(os.path.abspath(module.__file__))
+ try:
+ git_branch = subprocess.check_output(
+ ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
+ stderr=null,
+ cwd=cwd,
+ ).strip()
+ git_branch = "b=" + git_branch
+ except subprocess.CalledProcessError:
+ git_branch = ""
+
+ try:
+ git_tag = subprocess.check_output(
+ ['git', 'describe', '--exact-match'],
+ stderr=null,
+ cwd=cwd,
+ ).strip()
+ git_tag = "t=" + git_tag
+ except subprocess.CalledProcessError:
+ git_tag = ""
+
+ try:
+ git_commit = subprocess.check_output(
+ ['git', 'rev-parse', '--short', 'HEAD'],
+ stderr=null,
+ cwd=cwd,
+ ).strip()
+ except subprocess.CalledProcessError:
+ git_commit = ""
+
+ try:
+ dirty_string = "-this_is_a_dirty_checkout"
+ is_dirty = subprocess.check_output(
+ ['git', 'describe', '--dirty=' + dirty_string],
+ stderr=null,
+ cwd=cwd,
+ ).strip().endswith(dirty_string)
+
+ git_dirty = "dirty" if is_dirty else ""
+ except subprocess.CalledProcessError:
+ git_dirty = ""
+
+ if git_branch or git_tag or git_commit or git_dirty:
+ git_version = ",".join(
+ s for s in
+ (git_branch, git_tag, git_commit, git_dirty,)
+ if s
+ )
+
+ return (
+ "%s/%s (%s)" % (
+ name, module.__version__, git_version,
+ )
+ ).encode("ascii")
+ except Exception as e:
+ logger.info("Failed to check for git repository: %s", e)
+
+ return ("%s/%s" % (name, module.__version__,)).encode("ascii")
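A usage sketch (output illustrative): outside a git checkout the git probes fail and only `name/version` is returned.

```python
import synapse
from synapse.util.versionstring import get_version_string

version = get_version_string("Synapse", synapse)
# e.g. "Synapse/0.16.0 (b=develop,t=v0.16.0,abc1234)" in a tagged checkout,
# or just "Synapse/0.16.0" outside a git repository.
```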