diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 36d01a10af..f83c05df44 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -115,7 +115,7 @@ class EmailConfig(Config):
missing.append("email." + k)
if config.get("public_baseurl") is None:
- missing.append("public_base_url")
+ missing.append("public_baseurl")
if len(missing) > 0:
raise RuntimeError(
diff --git a/synapse/config/key.py b/synapse/config/key.py
index fe8386985c..ba2199bceb 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -76,7 +76,7 @@ class KeyConfig(Config):
config_dir_path, config["server_name"] + ".signing.key"
)
- self.signing_key = self.read_signing_key(signing_key_path)
+ self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {})
@@ -85,6 +85,14 @@ class KeyConfig(Config):
config.get("key_refresh_interval", "1d")
)
+ key_server_signing_keys_path = config.get("key_server_signing_keys_path")
+ if key_server_signing_keys_path:
+ self.key_server_signing_keys = self.read_signing_keys(
+ key_server_signing_keys_path, "key_server_signing_keys_path"
+ )
+ else:
+ self.key_server_signing_keys = list(self.signing_key)
+
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
key_servers = [{"server_name": "matrix.org"}]
@@ -210,16 +218,34 @@ class KeyConfig(Config):
#
#trusted_key_servers:
# - server_name: "matrix.org"
+ #
+
+ # The signing keys to use when acting as a trusted key server. If not specified
+ # defaults to the server signing key.
+ #
+ # Can contain multiple keys, one per line.
+ #
+ #key_server_signing_keys_path: "key_server_signing_keys.key"
"""
% locals()
)
- def read_signing_key(self, signing_key_path):
- signing_keys = self.read_file(signing_key_path, "signing_key")
+ def read_signing_keys(self, signing_key_path, name):
+ """Read the signing keys in the given path.
+
+ Args:
+ signing_key_path (str)
+ name (str): Associated config key name
+
+ Returns:
+ list[SigningKey]
+ """
+
+ signing_keys = self.read_file(signing_key_path, name)
try:
return read_signing_keys(signing_keys.splitlines(True))
except Exception as e:
- raise ConfigError("Error reading signing_key: %s" % (str(e)))
+ raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(self, old_signing_keys):
keys = {}
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 654accc843..7cfad192e8 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -29,7 +29,6 @@ from signedjson.key import (
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
- sign_json,
signature_ids,
verify_signed_json,
)
@@ -539,13 +538,7 @@ class BaseV2KeyFetcher(object):
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
- # re-sign the json with our own key, so that it is ready if we are asked to
- # give it out as a notary server
- signed_key_json = sign_json(
- response_json, self.config.server_name, self.config.signing_key[0]
- )
-
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ key_json_bytes = encode_canonical_json(response_json)
yield make_deferred_yieldable(
defer.gatherResults(
@@ -557,7 +550,7 @@ class BaseV2KeyFetcher(object):
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
+ key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9286ca3202..05fd49f3c1 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -43,7 +43,7 @@ from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.logging.context import nested_logging_context
-from synapse.logging.opentracing import log_kv, trace
+from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -811,12 +811,13 @@ class FederationHandlerRegistry(object):
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
- try:
- yield handler(origin, content)
- except SynapseError as e:
- logger.info("Failed to handle edu %r: %r", edu_type, e)
- except Exception:
- logger.exception("Failed to handle edu %r", edu_type)
+ with start_active_span_from_edu(content, "handle_edu"):
+ try:
+ yield handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception:
+ logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 52706302f2..62ca6a3e87 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -14,11 +14,19 @@
# limitations under the License.
import logging
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import HttpResponseException
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Transaction
+from synapse.logging.opentracing import (
+ extract_text_map,
+ set_tag,
+ start_active_span_follows_from,
+ tags,
+)
from synapse.util.metrics import measure_func
logger = logging.getLogger(__name__)
@@ -44,93 +52,109 @@ class TransactionManager(object):
@defer.inlineCallbacks
def send_new_transaction(self, destination, pending_pdus, pending_edus):
- # Sort based on the order field
- pending_pdus.sort(key=lambda t: t[1])
- pdus = [x[0] for x in pending_pdus]
- edus = pending_edus
+ # Make a transaction-sending opentracing span. This span follows on from
+ # all the edus in that transaction. This needs to be done since there is
+ # no active span here, so if the edus were not received by the remote the
+ # span would have no causality and it would be forgotten.
+ # The span_contexts is a generator so that it won't be evaluated if
+ # opentracing is disabled. (Yay speed!)
- success = True
+ span_contexts = (
+ extract_text_map(json.loads(edu.get_context())) for edu in pending_edus
+ )
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ with start_active_span_follows_from("send_transaction", span_contexts):
- txn_id = str(self._next_txn_id)
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[1])
+ pdus = [x[0] for x in pending_pdus]
+ edus = pending_edus
- logger.debug(
- "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
- destination,
- txn_id,
- len(pdus),
- len(edus),
- )
+ success = True
- transaction = Transaction.create_new(
- origin_server_ts=int(self.clock.time_msec()),
- transaction_id=txn_id,
- origin=self._server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- )
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- self._next_txn_id += 1
+ txn_id = str(self._next_txn_id)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
- destination,
- txn_id,
- transaction.transaction_id,
- len(pdus),
- len(edus),
- )
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
+ destination,
+ txn_id,
+ len(pdus),
+ len(edus),
+ )
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self.clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self._transport_layer.send_transaction(
- transaction, json_data_cb
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self._server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
)
- code = 200
- except HttpResponseException as e:
- code = e.code
- response = e.response
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
- raise e
+ self._next_txn_id += 1
- logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+ logger.info(
+ "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
+ destination,
+ txn_id,
+ transaction.transaction_id,
+ len(pdus),
+ len(edus),
+ )
- if code == 200:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self._transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+
+ if e.code in (401, 404, 429) or 500 <= e.code:
+ logger.info(
+ "TX [%s] {%s} got %d response", destination, txn_id, code
+ )
+ raise e
+
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+
+ if code == 200:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "TX [%s] {%s} Remote returned error for %s: %s",
+ destination,
+ txn_id,
+ e_id,
+ r,
+ )
+ else:
+ for p in pdus:
logger.warn(
- "TX [%s] {%s} Remote returned error for %s: %s",
+ "TX [%s] {%s} Failed to send event %s",
destination,
txn_id,
- e_id,
- r,
+ p.event_id,
)
- else:
- for p in pdus:
- logger.warn(
- "TX [%s] {%s} Failed to send event %s",
- destination,
- txn_id,
- p.event_id,
- )
- success = False
+ success = False
- return success
+ set_tag(tags.ERROR, not success)
+ return success
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a17148fc3c..dc53b4b170 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -38,7 +38,12 @@ from synapse.http.servlet import (
parse_string_from_args,
)
from synapse.logging.context import run_in_background
-from synapse.logging.opentracing import start_active_span_from_context, tags
+from synapse.logging.opentracing import (
+ start_active_span,
+ start_active_span_from_request,
+ tags,
+ whitelisted_homeserver,
+)
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@@ -288,20 +293,28 @@ class BaseFederationServlet(object):
logger.warn("authenticate_request failed: %s", e)
raise
- # Start an opentracing span
- with start_active_span_from_context(
- request.requestHeaders,
- "incoming-federation-request",
- tags={
- "request_id": request.get_request_id(),
- tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
- tags.HTTP_METHOD: request.get_method(),
- tags.HTTP_URL: request.get_redacted_uri(),
- tags.PEER_HOST_IPV6: request.getClientIP(),
- "authenticated_entity": origin,
- "servlet_name": request.request_metrics.name,
- },
- ):
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ "authenticated_entity": origin,
+ "servlet_name": request.request_metrics.name,
+ }
+
+ # Only accept the span context if the origin is authenticated
+ # and whitelisted
+ if origin and whitelisted_homeserver(origin):
+ scope = start_active_span_from_request(
+ request, "incoming-federation-request", tags=request_tags
+ )
+ else:
+ scope = start_active_span(
+ "incoming-federation-request", tags=request_tags
+ )
+
+ with scope:
if origin:
with ratelimiter.ratelimit(origin) as d:
await d
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 14aad8f09d..aa84621206 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -38,6 +38,9 @@ class Edu(JsonEncodedObject):
internal_keys = ["origin", "destination"]
+ def get_context(self):
+ return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
+
class Transaction(JsonEncodedObject):
""" A transaction is a list of Pdus and Edus to be sent to a remote home
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2f22f56ca4..d30a68b650 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -94,6 +94,16 @@ class AdminHandler(BaseHandler):
return ret
+ def set_user_server_admin(self, user, admin):
+ """
+ Set the admin bit on a user.
+
+ Args:
+ user_id (UserID): the (necessarily local) user to manipulate
+ admin (bool): whether or not the user should be an admin of this server
+ """
+ return self.store.set_server_admin(user, admin)
+
@defer.inlineCallbacks
def export_user_data(self, user_id, writer):
"""Write all data we have on the user to the given writer.
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index e1ebb6346c..c7d56779b8 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,9 +15,17 @@
import logging
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ set_tag,
+ start_active_span,
+ whitelisted_homeserver,
+)
from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string
@@ -100,14 +108,21 @@ class DeviceMessageHandler(object):
message_id = random_string(16)
+ context = get_active_span_text_map()
+
remote_edu_contents = {}
for destination, messages in remote_messages.items():
- remote_edu_contents[destination] = {
- "messages": messages,
- "sender": sender_user_id,
- "type": message_type,
- "message_id": message_id,
- }
+ with start_active_span("to_device_for_user"):
+ set_tag("destination", destination)
+ remote_edu_contents[destination] = {
+ "messages": messages,
+ "sender": sender_user_id,
+ "type": message_type,
+ "message_id": message_id,
+ "org.matrix.opentracing_context": json.dumps(context)
+ if whitelisted_homeserver(destination)
+ else None,
+ }
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 2cc237e6a5..8690f69d45 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -34,7 +34,7 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-MAX_DISPLAYNAME_LEN = 100
+MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fd07bf7b8e..c186b31f59 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -300,7 +300,7 @@ class RestServlet(object):
http_server.register_paths(
method,
patterns,
- trace_servlet(servlet_classname, method_handler),
+ trace_servlet(servlet_classname)(method_handler),
servlet_classname,
)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 6b706e1892..dd296027a1 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -149,6 +149,9 @@ unchartered waters will require the enforcement of the whitelist.
``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
in a destination and compares it to the whitelist.
+Most injection methods take a 'destination' arg. The context will only be injected
+if the destination matches the whitelist or the destination is None.
+
=======
Gotchas
=======
@@ -174,10 +177,48 @@ from twisted.internet import defer
from synapse.config import ConfigError
+# Helper class
+
+
+class _DummyTagNames(object):
+ """wrapper of opentracings tags. We need to have them if we
+ want to reference them without opentracing around. Clearly they
+ should never actually show up in a trace. `set_tags` overwrites
+ these with the correct ones."""
+
+ INVALID_TAG = "invalid-tag"
+ COMPONENT = INVALID_TAG
+ DATABASE_INSTANCE = INVALID_TAG
+ DATABASE_STATEMENT = INVALID_TAG
+ DATABASE_TYPE = INVALID_TAG
+ DATABASE_USER = INVALID_TAG
+ ERROR = INVALID_TAG
+ HTTP_METHOD = INVALID_TAG
+ HTTP_STATUS_CODE = INVALID_TAG
+ HTTP_URL = INVALID_TAG
+ MESSAGE_BUS_DESTINATION = INVALID_TAG
+ PEER_ADDRESS = INVALID_TAG
+ PEER_HOSTNAME = INVALID_TAG
+ PEER_HOST_IPV4 = INVALID_TAG
+ PEER_HOST_IPV6 = INVALID_TAG
+ PEER_PORT = INVALID_TAG
+ PEER_SERVICE = INVALID_TAG
+ SAMPLING_PRIORITY = INVALID_TAG
+ SERVICE = INVALID_TAG
+ SPAN_KIND = INVALID_TAG
+ SPAN_KIND_CONSUMER = INVALID_TAG
+ SPAN_KIND_PRODUCER = INVALID_TAG
+ SPAN_KIND_RPC_CLIENT = INVALID_TAG
+ SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
try:
import opentracing
+
+ tags = opentracing.tags
except ImportError:
opentracing = None
+ tags = _DummyTagNames
try:
from jaeger_client import Config as JaegerConfig
from synapse.logging.scopecontextmanager import LogContextScopeManager
@@ -252,10 +293,6 @@ def init_tracer(config):
scope_manager=LogContextScopeManager(config),
).initialize_tracer()
- # Set up tags to be opentracing's tags
- global tags
- tags = opentracing.tags
-
# Whitelisting
@@ -334,8 +371,8 @@ def start_active_span_follows_from(operation_name, contexts):
return scope
-def start_active_span_from_context(
- headers,
+def start_active_span_from_request(
+ request,
operation_name,
references=None,
tags=None,
@@ -344,9 +381,9 @@ def start_active_span_from_context(
finish_on_close=True,
):
"""
- Extracts a span context from Twisted Headers.
+ Extracts a span context from a Twisted Request.
args:
- headers (twisted.web.http_headers.Headers)
+ headers (twisted.web.http.Request)
For the other args see opentracing.tracer
@@ -360,7 +397,9 @@ def start_active_span_from_context(
if opentracing is None:
return _noop_context_manager()
- header_dict = {k.decode(): v[0].decode() for k, v in headers.getAllRawHeaders()}
+ header_dict = {
+ k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+ }
context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
return opentracing.tracer.start_active_span(
@@ -448,7 +487,7 @@ def set_operation_name(operation_name):
@only_if_tracing
-def inject_active_span_twisted_headers(headers, destination):
+def inject_active_span_twisted_headers(headers, destination, check_destination=True):
"""
Injects a span context into twisted headers in-place
@@ -467,7 +506,7 @@ def inject_active_span_twisted_headers(headers, destination):
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
- if not whitelisted_homeserver(destination):
+ if check_destination and not whitelisted_homeserver(destination):
return
span = opentracing.tracer.active_span
@@ -479,7 +518,7 @@ def inject_active_span_twisted_headers(headers, destination):
@only_if_tracing
-def inject_active_span_byte_dict(headers, destination):
+def inject_active_span_byte_dict(headers, destination, check_destination=True):
"""
Injects a span context into a dict where the headers are encoded as byte
strings
@@ -511,7 +550,7 @@ def inject_active_span_byte_dict(headers, destination):
@only_if_tracing
-def inject_active_span_text_map(carrier, destination=None):
+def inject_active_span_text_map(carrier, destination, check_destination=True):
"""
Injects a span context into a dict
@@ -532,7 +571,7 @@ def inject_active_span_text_map(carrier, destination=None):
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
- if destination and not whitelisted_homeserver(destination):
+ if check_destination and not whitelisted_homeserver(destination):
return
opentracing.tracer.inject(
@@ -540,6 +579,29 @@ def inject_active_span_text_map(carrier, destination=None):
)
+def get_active_span_text_map(destination=None):
+ """
+ Gets a span context as a dict. This can be used instead of manually
+ injecting a span into an empty carrier.
+
+ Args:
+ destination (str): the name of the remote server.
+
+ Returns:
+ dict: the active span's context if opentracing is enabled, otherwise empty.
+ """
+
+ if not opentracing or (destination and not whitelisted_homeserver(destination)):
+ return {}
+
+ carrier = {}
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+
+ return carrier
+
+
def active_span_context_as_string():
"""
Returns:
@@ -689,65 +751,43 @@ def tag_args(func):
return _tag_args_inner
-def trace_servlet(servlet_name, func):
+def trace_servlet(servlet_name, extract_context=False):
"""Decorator which traces a serlet. It starts a span with some servlet specific
- tags such as the servlet_name and request information"""
- if not opentracing:
- return func
+ tags such as the servlet_name and request information
- @wraps(func)
- @defer.inlineCallbacks
- def _trace_servlet_inner(request, *args, **kwargs):
- with start_active_span(
- "incoming-client-request",
- tags={
+ Args:
+ servlet_name (str): The name to be used for the span's operation_name
+ extract_context (bool): Whether to attempt to extract the opentracing
+ context from the request the servlet is handling.
+
+ """
+
+ def _trace_servlet_inner_1(func):
+ if not opentracing:
+ return func
+
+ @wraps(func)
+ @defer.inlineCallbacks
+ def _trace_servlet_inner(request, *args, **kwargs):
+ request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
- "servlet_name": servlet_name,
- },
- ):
- result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- return result
-
- return _trace_servlet_inner
-
-
-# Helper class
+ }
+ if extract_context:
+ scope = start_active_span_from_request(
+ request, servlet_name, tags=request_tags
+ )
+ else:
+ scope = start_active_span(servlet_name, tags=request_tags)
-class _DummyTagNames(object):
- """wrapper of opentracings tags. We need to have them if we
- want to reference them without opentracing around. Clearly they
- should never actually show up in a trace. `set_tags` overwrites
- these with the correct ones."""
-
- INVALID_TAG = "invalid-tag"
- COMPONENT = INVALID_TAG
- DATABASE_INSTANCE = INVALID_TAG
- DATABASE_STATEMENT = INVALID_TAG
- DATABASE_TYPE = INVALID_TAG
- DATABASE_USER = INVALID_TAG
- ERROR = INVALID_TAG
- HTTP_METHOD = INVALID_TAG
- HTTP_STATUS_CODE = INVALID_TAG
- HTTP_URL = INVALID_TAG
- MESSAGE_BUS_DESTINATION = INVALID_TAG
- PEER_ADDRESS = INVALID_TAG
- PEER_HOSTNAME = INVALID_TAG
- PEER_HOST_IPV4 = INVALID_TAG
- PEER_HOST_IPV6 = INVALID_TAG
- PEER_PORT = INVALID_TAG
- PEER_SERVICE = INVALID_TAG
- SAMPLING_PRIORITY = INVALID_TAG
- SERVICE = INVALID_TAG
- SPAN_KIND = INVALID_TAG
- SPAN_KIND_CONSUMER = INVALID_TAG
- SPAN_KIND_PRODUCER = INVALID_TAG
- SPAN_KIND_RPC_CLIENT = INVALID_TAG
- SPAN_KIND_RPC_SERVER = INVALID_TAG
+ with scope:
+ result = yield defer.maybeDeferred(func, request, *args, **kwargs)
+ return result
+ return _trace_servlet_inner
-tags = _DummyTagNames
+ return _trace_servlet_inner_1
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2e0594e581..c4be9273f6 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -22,6 +22,7 @@ from six.moves import urllib
from twisted.internet import defer
+import synapse.logging.opentracing as opentracing
from synapse.api.errors import (
CodeMessageException,
HttpResponseException,
@@ -165,8 +166,12 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
+ headers = {}
+ opentracing.inject_active_span_byte_dict(
+ headers, None, check_destination=False
+ )
try:
- result = yield request_func(uri, data)
+ result = yield request_func(uri, data, headers=headers)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -205,7 +210,14 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- http_server.register_paths(method, [pattern], handler, self.__class__.__name__)
+ http_server.register_paths(
+ method,
+ [pattern],
+ opentracing.trace_servlet(self.__class__.__name__, extract_context=True)(
+ handler
+ ),
+ self.__class__.__name__,
+ )
def _cached_handler(self, request, txn_id, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 0dce256840..9ab1c2c9e0 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -44,6 +44,7 @@ from synapse.rest.admin._base import (
from synapse.rest.admin.media import register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
+from synapse.rest.admin.users import UserAdminServlet
from synapse.types import UserID, create_requester
from synapse.util.versionstring import get_version_string
@@ -742,6 +743,7 @@ def register_servlets(hs, http_server):
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
+ UserAdminServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
new file mode 100644
index 0000000000..b0fddb6898
--- /dev/null
+++ b/synapse/rest/admin/users.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin import assert_requester_is_admin
+from synapse.types import UserID
+
+
+class UserAdminServlet(RestServlet):
+ """
+ Set whether or not a user is a server administrator.
+
+ Note that only local users can be server administrators, and that an
+ administrator may not demote themselves.
+
+ Only server administrators can use this API.
+
+ Example:
+ PUT /_synapse/admin/v1/users/@reivilibre:librepush.net/admin
+ {
+ "admin": true
+ }
+ """
+
+ PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>@[^/]*)/admin$"),)
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id):
+ yield assert_requester_is_admin(self.auth, request)
+ requester = yield self.auth.get_user_by_req(request)
+ auth_user = requester.user
+
+ target_user = UserID.from_string(user_id)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["admin"])
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Only local users can be admins of this homeserver")
+
+ set_admin_to = bool(body["admin"])
+
+ if target_user == auth_user and not set_admin_to:
+ raise SynapseError(400, "You may not demote yourself.")
+
+ yield self.handlers.admin_handler.set_user_server_admin(
+ target_user, set_admin_to
+ )
+
+ return (200, {})
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 05ea1459e3..9510a1e2b0 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -16,7 +16,6 @@
import hmac
import logging
-from hashlib import sha1
from six import string_types
@@ -239,14 +238,12 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
- desired_password = None
if "password" in body:
if (
not isinstance(body["password"], string_types)
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
- desired_password = body["password"]
desired_username = None
if "username" in body:
@@ -261,8 +258,8 @@ class RegisterRestServlet(RestServlet):
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
- # fork off as soon as possible for ASes and shared secret auth which
- # have completely different registration flows to normal users
+ # fork off as soon as possible for ASes which have completely
+ # different registration flows to normal users
# == Application Service Registration ==
if appservice:
@@ -285,8 +282,8 @@ class RegisterRestServlet(RestServlet):
return (200, result) # we throw for non 200 responses
return
- # for either shared secret or regular registration, downcase the
- # provided username before attempting to register it. This should mean
+ # for regular registration, downcase the provided username before
+ # attempting to register it. This should mean
# that people who try to register with upper-case in their usernames
# don't get a nasty surprise. (Note that we treat username
# case-insenstively in login, so they are free to carry on imagining
@@ -294,16 +291,6 @@ class RegisterRestServlet(RestServlet):
if desired_username is not None:
desired_username = desired_username.lower()
- # == Shared Secret Registration == (e.g. create new user scripts)
- if "mac" in body:
- # FIXME: Should we really be determining if this is shared secret
- # auth based purely on the 'mac' key?
- result = yield self._do_shared_secret_registration(
- desired_username, desired_password, body
- )
- return (200, result) # we throw for non 200 responses
- return
-
# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
@@ -513,42 +500,6 @@ class RegisterRestServlet(RestServlet):
return (yield self._create_registration_details(user_id, body))
@defer.inlineCallbacks
- def _do_shared_secret_registration(self, username, password, body):
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
- if not username:
- raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON
- )
-
- # use the username from the original request rather than the
- # downcased one in `username` for the mac calculation
- user = body["username"].encode("utf-8")
-
- # str() because otherwise hmac complains that 'unicode' does not
- # have the buffer interface
- got_mac = str(body["mac"])
-
- # FIXME this is different to the /v1/register endpoint, which
- # includes the password and admin flag in the hashed text. Why are
- # these different?
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- msg=user,
- digestmod=sha1,
- ).hexdigest()
-
- if not compare_digest(want_mac, got_mac):
- raise SynapseError(403, "HMAC incorrect")
-
- user_id = yield self.registration_handler.register_user(
- localpart=username, password=password
- )
-
- result = yield self._create_registration_details(user_id, body)
- return result
-
- @defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 031a316693..55580bc59e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
-from io import BytesIO
+
+from canonicaljson import encode_canonical_json, json
+from signedjson.sign import sign_json
from twisted.internet import defer
@@ -95,6 +97,7 @@ class RemoteKey(DirectServeResource):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.config = hs.config
@wrap_json_request_handler
async def _async_render_GET(self, request):
@@ -214,15 +217,14 @@ class RemoteKey(DirectServeResource):
yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
- result_io = BytesIO()
- result_io.write(b'{"server_keys":')
- sep = b"["
- for json_bytes in json_results:
- result_io.write(sep)
- result_io.write(json_bytes)
- sep = b","
- if sep == b"[":
- result_io.write(sep)
- result_io.write(b"]}")
-
- respond_with_json_bytes(request, 200, result_io.getvalue())
+ signed_keys = []
+ for key_json in json_results:
+ key_json = json.loads(key_json)
+ for signing_key in self.config.key_server_signing_keys:
+ key_json = sign_json(key_json, self.config.server_name, signing_key)
+
+ signed_keys.append(key_json)
+
+ results = {"server_keys": signed_keys}
+
+ respond_with_json_bytes(request, 200, encode_canonical_json(results))
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 5e8fda4b65..20177b44e7 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -34,7 +34,7 @@ class WellKnownBuilder(object):
self._config = hs.config
def get_well_known(self):
- # if we don't have a public_base_url, we can't help much here.
+ # if we don't have a public_baseurl, we can't help much here.
if self._config.public_baseurl is None:
return None
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 8f72d92895..e11881161d 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -21,6 +21,11 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ trace,
+ whitelisted_homeserver,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
@@ -73,6 +78,7 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
+ @trace
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers
@@ -127,8 +133,15 @@ class DeviceWorkerStore(SQLBaseStore):
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
- # maps (user_id, device_id) -> stream_id
+ # maps (user_id, device_id) -> (stream_id, opentracing_context)
# as long as their stream_id does not match that of the last row
+ #
+ # opentracing_context contains the opentracing metadata for the request
+ # that created the poke
+ #
+ # The most recent request's opentracing_context is used as the
+ # context which created the Edu.
+
query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
@@ -136,7 +149,14 @@ class DeviceWorkerStore(SQLBaseStore):
break
key = (update[0], update[1])
- query_map[key] = max(query_map.get(key, 0), update[2])
+
+ update_context = update[3]
+ update_stream_id = update[2]
+
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
+
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -171,7 +191,7 @@ class DeviceWorkerStore(SQLBaseStore):
List: List of device updates
"""
sql = """
- SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
+ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
ORDER BY stream_id
LIMIT ?
@@ -187,8 +207,9 @@ class DeviceWorkerStore(SQLBaseStore):
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- query_map (Dict[(str, str): int]): Dictionary mapping
- user_id/device_id to update stream_id
+ query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
+ user_id/device_id to update stream_id and the relevent json-encoded
+ opentracing context
Returns:
List[Dict]: List of objects representing an device update EDU
@@ -210,12 +231,13 @@ class DeviceWorkerStore(SQLBaseStore):
destination, user_id, from_stream_id
)
for device_id, device in iteritems(user_devices):
- stream_id = query_map[(user_id, device_id)]
+ stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
+ "org.matrix.opentracing_context": opentracing_context,
}
prev_id = stream_id
@@ -814,6 +836,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
],
)
+ context = get_active_span_text_map()
+
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -825,6 +849,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
+ "opentracing_context": json.dumps(context)
+ if whitelisted_homeserver(destination)
+ else None,
}
for destination in hosts
for device_id in device_ids
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d20eacda59..e96eed8a6d 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -238,6 +238,13 @@ def _upgrade_existing_database(
logger.debug("applied_delta_files: %s", applied_delta_files)
+ if isinstance(database_engine, PostgresEngine):
+ specific_engine_extension = ".postgres"
+ else:
+ specific_engine_extension = ".sqlite"
+
+ specific_engine_extensions = (".sqlite", ".postgres")
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.info("Upgrading schema to v%d", v)
@@ -274,15 +281,22 @@ def _upgrade_existing_database(
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
- pass
+ continue
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.info("Applying schema %s", relative_path)
executescript(cur, absolute_path)
+ elif ext == specific_engine_extension and root_name.endswith(".sql"):
+ # A .sql file specific to our engine; just read and execute it
+ logger.info("Applying engine-specific schema %s", relative_path)
+ executescript(cur, absolute_path)
+ elif ext in specific_engine_extensions and root_name.endswith(".sql"):
+ # A .sql file for a different engine; skip it.
+ continue
else:
# Not a valid delta file.
- logger.warn(
- "Found directory entry that did not end in .py or" " .sql: %s",
+ logger.warning(
+ "Found directory entry that did not end in .py or .sql: %s",
relative_path,
)
continue
@@ -290,7 +304,7 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
),
(v, relative_path),
)
@@ -298,7 +312,7 @@ def _upgrade_existing_database(
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(v, True),
)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 55e4e84d71..9027b917c1 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -272,6 +272,14 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def is_server_admin(self, user):
+ """Determines if a user is an admin of this homeserver.
+
+ Args:
+ user (UserID): user ID of the user to test
+
+ Returns (bool):
+ true iff the user is a server admin, false otherwise.
+ """
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
@@ -282,6 +290,21 @@ class RegistrationWorkerStore(SQLBaseStore):
return res if res else False
+ def set_server_admin(self, user, admin):
+ """Sets whether a user is an admin of this homeserver.
+
+ Args:
+ user (UserID): user ID of the user to test
+ admin (bool): true iff the user is to be a server admin,
+ false otherwise.
+ """
+ return self._simple_update_one(
+ table="users",
+ keyvalues={"name": user.to_string()},
+ updatevalues={"admin": 1 if admin else 0},
+ desc="set_server_admin",
+ )
+
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
diff --git a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql
new file mode 100644
index 0000000000..41807eb1e7
--- /dev/null
+++ b/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Opentracing context data for inclusion in the device_list_update EDUs, as a
+ * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination).
+ */
+ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT;
|