diff --git a/demo/start.sh b/demo/start.sh
index 5b3daef57f..b9cc14b9d2 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -31,6 +31,7 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config
python -m synapse.app.homeserver \
--generate-config \
+ --enable_registration \
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 4911f0896b..24f15f3154 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+ preserve_context_over_fn, preserve_context_over_deferred
+)
import simplejson as json
import logging
@@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
- with PreserveLoggingContext():
- protocol = yield endpoint.connect(factory)
- server_response, server_certificate = yield protocol.remote_key
- defer.returnValue((server_response, server_certificate))
- return
+ protocol = yield preserve_context_over_fn(
+ endpoint.connect, factory
+ )
+ server_response, server_certificate = yield preserve_context_over_deferred(
+ protocol.remote_key
+ )
+ defer.returnValue((server_response, server_certificate))
+ return
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2b46188c91..cd79e23f4b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -20,7 +20,6 @@ from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import synapse.metrics
@@ -123,29 +122,28 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- with PreserveLoggingContext():
- results = []
-
- for pdu in pdu_list:
- d = self._handle_new_pdu(transaction.origin, pdu)
-
- try:
- yield d
- results.append({})
- except FederationError as e:
- self.send_failure(e, transaction.origin)
- results.append({"error": str(e)})
- except Exception as e:
- results.append({"error": str(e)})
- logger.exception("Failed to handle PDU")
-
- if hasattr(transaction, "edus"):
- for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(
- transaction.origin,
- edu.edu_type,
- edu.content
- )
+ results = []
+
+ for pdu in pdu_list:
+ d = self._handle_new_pdu(transaction.origin, pdu)
+
+ try:
+ yield d
+ results.append({})
+ except FederationError as e:
+ self.send_failure(e, transaction.origin)
+ results.append({"error": str(e)})
+ except Exception as e:
+ results.append({"error": str(e)})
+ logger.exception("Failed to handle PDU")
+
+ if hasattr(transaction, "edus"):
+ for edu in [Edu(**x) for x in transaction.edus]:
+ self.received_edu(
+ transaction.origin,
+ edu.edu_type,
+ edu.content
+ )
for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f9f855213b..993d33ba47 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -15,7 +15,6 @@
from twisted.internet import defer
-from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
@@ -81,10 +80,9 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
- with PreserveLoggingContext():
- events, tokens = yield self.notifier.get_events_for(
- auth_user, room_ids, pagin_config, timeout
- )
+ events, tokens = yield self.notifier.get_events_for(
+ auth_user, room_ids, pagin_config, timeout
+ )
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 9e15610401..6ae39a1d37 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import synapse.metrics
@@ -278,15 +277,14 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap
- with PreserveLoggingContext():
- if now_online and not was_polling:
- self.start_polling_presence(target_user, state=state)
- elif not now_online and was_polling:
- self.stop_polling_presence(target_user)
+ if now_online and not was_polling:
+ self.start_polling_presence(target_user, state=state)
+ elif not now_online and was_polling:
+ self.stop_polling_presence(target_user)
- # TODO(paul): perform a presence push as part of start/stop poll so
- # we don't have to do this all the time
- self.changed_presencelike_data(target_user, state)
+ # TODO(paul): perform a presence push as part of start/stop poll so
+ # we don't have to do this all the time
+ self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None):
if now is None:
@@ -408,10 +406,10 @@ class PresenceHandler(BaseHandler):
yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string()
)
- with PreserveLoggingContext():
- self.start_polling_presence(
- observer_user, target_user=observed_user
- )
+
+ self.start_polling_presence(
+ observer_user, target_user=observed_user
+ )
@defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user):
@@ -430,10 +428,9 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
- with PreserveLoggingContext():
- self.stop_polling_presence(
- observer_user, target_user=observed_user
- )
+ self.stop_polling_presence(
+ observer_user, target_user=observed_user
+ )
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
@@ -766,8 +763,7 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]:
del self._remote_sendmap[user]
- with PreserveLoggingContext():
- yield defer.DeferredList(deferreds, consumeErrors=True)
+ yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index ee2732b848..a7de7a80f8 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -154,14 +154,13 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user):
defer.returnValue(None)
- with PreserveLoggingContext():
- (displayname, avatar_url) = yield defer.gatherResults(
- [
- self.store.get_profile_displayname(user.localpart),
- self.store.get_profile_avatar_url(user.localpart),
- ],
- consumeErrors=True
- )
+ (displayname, avatar_url) = yield defer.gatherResults(
+ [
+ self.store.get_profile_displayname(user.localpart),
+ self.store.get_profile_avatar_url(user.localpart),
+ ],
+ consumeErrors=True
+ )
state["displayname"] = displayname
state["avatar_url"] = avatar_url
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e8a5dedab4..5b3cefb2dc 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.api.errors import CodeMessageException
+from synapse.util.logcontext import preserve_context_over_fn
from syutil.jsonutil import encode_canonical_json
import synapse.metrics
@@ -61,7 +62,10 @@ class SimpleHttpClient(object):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.inc(method)
- d = self.agent.request(method, *args, **kwargs)
+ d = preserve_context_over_fn(
+ self.agent.request,
+ method, *args, **kwargs
+ )
def _cb(response):
incoming_responses_counter.inc(method, response.code)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7fa295cad5..c99d237c73 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics
from syutil.jsonutil import encode_canonical_json
@@ -144,22 +144,22 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, url_bytes, headers_dict)
try:
- with PreserveLoggingContext():
- request_deferred = self.agent.request(
- destination,
- endpoint,
- method,
- path_bytes,
- param_bytes,
- query_bytes,
- Headers(headers_dict),
- producer
- )
+ request_deferred = preserve_context_over_fn(
+ self.agent.request,
+ destination,
+ endpoint,
+ method,
+ path_bytes,
+ param_bytes,
+ query_bytes,
+ Headers(headers_dict),
+ producer
+ )
- response = yield self.clock.time_bound_deferred(
- request_deferred,
- time_out=60,
- )
+ response = yield self.clock.time_bound_deferred(
+ request_deferred,
+ time_out=60,
+ )
logger.debug("Got response to %s", method)
break
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 78eb28e4b2..fbbccb38e6 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_deferred
from synapse.types import StreamToken
import synapse.metrics
@@ -223,11 +223,10 @@ class Notifier(object):
def eb(failure):
logger.exception("Failed to notify listener", failure)
- with PreserveLoggingContext():
- yield defer.DeferredList(
+ yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
- )
+ )
@defer.inlineCallbacks
@log_function
@@ -298,11 +297,10 @@ class Notifier(object):
failure.getTracebackObject())
)
- with PreserveLoggingContext():
- yield defer.DeferredList(
- [notify(l).addErrback(eb) for l in listeners],
- consumeErrors=True,
- )
+ yield defer.DeferredList(
+ [notify(l).addErrback(eb) for l in listeners],
+ consumeErrors=True,
+ )
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ee5587c721..b0020f51db 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,7 @@ from synapse.api.errors import StoreError
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
@@ -419,10 +419,11 @@ class SQLBaseStore(object):
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield preserve_context_over_fn(
+ self._db_pool.runWithConnection,
+ inner_func, *args, **kwargs
+ )
+
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 79109d0b19..364b927851 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -50,8 +50,10 @@ class Clock(object):
current_context = LoggingContext.current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
- callback()
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ callback()
+
return reactor.callLater(delay, wrapped_callback)
def cancel_call_later(self, timer):
diff --git a/synapse/util/async.py b/synapse/util/async.py
index d8febdb90c..f78395a431 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,15 +16,13 @@
from twisted.internet import defer, reactor
-from .logcontext import PreserveLoggingContext
+from .logcontext import preserve_context_over_deferred
-@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
- with PreserveLoggingContext():
- yield d
+ return preserve_context_over_deferred(d)
def run_on_reactor():
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index da7872e95d..192e3f49f0 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
import threading
import logging
@@ -129,3 +131,32 @@ class PreserveLoggingContext(object):
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context
+
+
+def preserve_context_over_fn(fn, *args, **kwargs):
+ with PreserveLoggingContext():
+ deferred = fn(*args, **kwargs)
+
+ return preserve_context_over_deferred(deferred)
+
+
+def preserve_context_over_deferred(deferred):
+ d = defer.Deferred()
+
+ current_context = LoggingContext.current_context()
+
+ def cb(res):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.callback(res)
+ return res
+
+ def eb(failure):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = current_context
+ res = d.errback(failure)
+ return res
+
+ deferred.addCallbacks(cb, eb)
+
+ return d
|