summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/federation_server.py42
-rw-r--r--synapse/federation/sender/transaction_manager.py193
-rw-r--r--synapse/federation/transport/client.py46
-rw-r--r--synapse/federation/transport/server.py88
-rw-r--r--synapse/federation/units.py6
6 files changed, 247 insertions, 136 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index bec3080895..6ee6216660 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -355,7 +355,7 @@ class FederationClient(FederationBase):
 
             auth_chain.sort(key=lambda e: e.depth)
 
-            return (pdus, auth_chain)
+            return pdus, auth_chain
         except HttpResponseException as e:
             if e.code == 400 or e.code == 404:
                 logger.info("Failed to use get_room_state_ids API, falling back")
@@ -404,7 +404,7 @@ class FederationClient(FederationBase):
 
         signed_auth.sort(key=lambda e: e.depth)
 
-        return (signed_pdus, signed_auth)
+        return signed_pdus, signed_auth
 
     @defer.inlineCallbacks
     def get_events_from_store_or_dest(self, destination, room_id, event_ids):
@@ -429,7 +429,7 @@ class FederationClient(FederationBase):
             missing_events.discard(k)
 
         if not missing_events:
-            return (signed_events, failed_to_fetch)
+            return signed_events, failed_to_fetch
 
         logger.debug(
             "Fetching unknown state/auth events %s for room %s",
@@ -465,7 +465,7 @@ class FederationClient(FederationBase):
             # We removed all events we successfully fetched from `batch`
             failed_to_fetch.update(batch)
 
-        return (signed_events, failed_to_fetch)
+        return signed_events, failed_to_fetch
 
     @defer.inlineCallbacks
     @log_function
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d216c46dfe..da06ab379d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -43,6 +43,7 @@ from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
 from synapse.http.endpoint import parse_server_name
 from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
 from synapse.logging.utils import log_function
 from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
@@ -99,7 +100,7 @@ class FederationServer(FederationBase):
 
             res = self._transaction_from_pdus(pdus).get_dict()
 
-        return (200, res)
+        return 200, res
 
     @defer.inlineCallbacks
     @log_function
@@ -162,7 +163,7 @@ class FederationServer(FederationBase):
             yield self.transaction_actions.set_response(
                 origin, transaction, 400, response
             )
-            return (400, response)
+            return 400, response
 
         received_pdus_counter.inc(len(transaction.pdus))
 
@@ -264,7 +265,7 @@ class FederationServer(FederationBase):
         logger.debug("Returning: %s", str(response))
 
         yield self.transaction_actions.set_response(origin, transaction, 200, response)
-        return (200, response)
+        return 200, response
 
     @defer.inlineCallbacks
     def received_edu(self, origin, edu_type, content):
@@ -297,7 +298,7 @@ class FederationServer(FederationBase):
                 event_id,
             )
 
-        return (200, resp)
+        return 200, resp
 
     @defer.inlineCallbacks
     def on_state_ids_request(self, origin, room_id, event_id):
@@ -314,7 +315,7 @@ class FederationServer(FederationBase):
         state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
 
-        return (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
+        return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
 
     @defer.inlineCallbacks
     def _on_context_state_request_compute(self, room_id, event_id):
@@ -344,15 +345,15 @@ class FederationServer(FederationBase):
         pdu = yield self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
-            return (200, self._transaction_from_pdus([pdu]).get_dict())
+            return 200, self._transaction_from_pdus([pdu]).get_dict()
         else:
-            return (404, "")
+            return 404, ""
 
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
         received_queries_counter.labels(query_type).inc()
         resp = yield self.registry.on_query(query_type, args)
-        return (200, resp)
+        return 200, resp
 
     @defer.inlineCallbacks
     def on_make_join_request(self, origin, room_id, user_id, supported_versions):
@@ -434,7 +435,7 @@ class FederationServer(FederationBase):
 
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
-        return (200, {})
+        return 200, {}
 
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
@@ -445,7 +446,7 @@ class FederationServer(FederationBase):
             time_now = self._clock.time_msec()
             auth_pdus = yield self.handler.on_event_auth(event_id)
             res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
-        return (200, res)
+        return 200, res
 
     @defer.inlineCallbacks
     def on_query_auth_request(self, origin, content, room_id, event_id):
@@ -498,7 +499,7 @@ class FederationServer(FederationBase):
                 "missing": ret.get("missing", []),
             }
 
-        return (200, send_content)
+        return 200, send_content
 
     @log_function
     def on_query_client_keys(self, origin, content):
@@ -507,6 +508,7 @@ class FederationServer(FederationBase):
     def on_query_user_devices(self, origin, user_id):
         return self.on_query_request("user_devices", user_id)
 
+    @trace
     @defer.inlineCallbacks
     @log_function
     def on_claim_client_keys(self, origin, content):
@@ -515,6 +517,7 @@ class FederationServer(FederationBase):
             for device_id, algorithm in device_keys.items():
                 query.append((user_id, device_id, algorithm))
 
+        log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
         results = yield self.store.claim_e2e_one_time_keys(query)
 
         json_result = {}
@@ -666,9 +669,9 @@ class FederationServer(FederationBase):
         return ret
 
     @defer.inlineCallbacks
-    def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+    def on_exchange_third_party_invite_request(self, room_id, event_dict):
         ret = yield self.handler.on_exchange_third_party_invite_request(
-            origin, room_id, event_dict
+            room_id, event_dict
         )
         return ret
 
@@ -808,12 +811,13 @@ class FederationHandlerRegistry(object):
         if not handler:
             logger.warn("No handler registered for EDU type %s", edu_type)
 
-        try:
-            yield handler(origin, content)
-        except SynapseError as e:
-            logger.info("Failed to handle edu %r: %r", edu_type, e)
-        except Exception:
-            logger.exception("Failed to handle edu %r", edu_type)
+        with start_active_span_from_edu(content, "handle_edu"):
+            try:
+                yield handler(origin, content)
+            except SynapseError as e:
+                logger.info("Failed to handle edu %r: %r", edu_type, e)
+            except Exception:
+                logger.exception("Failed to handle edu %r", edu_type)
 
     def on_query(self, query_type, args):
         handler = self.query_handlers.get(query_type)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 52706302f2..5b6c79c51a 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -14,11 +14,20 @@
 # limitations under the License.
 import logging
 
+from canonicaljson import json
+
 from twisted.internet import defer
 
 from synapse.api.errors import HttpResponseException
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Transaction
+from synapse.logging.opentracing import (
+    extract_text_map,
+    set_tag,
+    start_active_span_follows_from,
+    tags,
+    whitelisted_homeserver,
+)
 from synapse.util.metrics import measure_func
 
 logger = logging.getLogger(__name__)
@@ -44,93 +53,115 @@ class TransactionManager(object):
     @defer.inlineCallbacks
     def send_new_transaction(self, destination, pending_pdus, pending_edus):
 
-        # Sort based on the order field
-        pending_pdus.sort(key=lambda t: t[1])
-        pdus = [x[0] for x in pending_pdus]
-        edus = pending_edus
-
-        success = True
-
-        logger.debug("TX [%s] _attempt_new_transaction", destination)
-
-        txn_id = str(self._next_txn_id)
-
-        logger.debug(
-            "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
-            destination,
-            txn_id,
-            len(pdus),
-            len(edus),
-        )
-
-        transaction = Transaction.create_new(
-            origin_server_ts=int(self.clock.time_msec()),
-            transaction_id=txn_id,
-            origin=self._server_name,
-            destination=destination,
-            pdus=pdus,
-            edus=edus,
-        )
-
-        self._next_txn_id += 1
-
-        logger.info(
-            "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
-            destination,
-            txn_id,
-            transaction.transaction_id,
-            len(pdus),
-            len(edus),
-        )
-
-        # Actually send the transaction
-
-        # FIXME (erikj): This is a bit of a hack to make the Pdu age
-        # keys work
-        def json_data_cb():
-            data = transaction.get_dict()
-            now = int(self.clock.time_msec())
-            if "pdus" in data:
-                for p in data["pdus"]:
-                    if "age_ts" in p:
-                        unsigned = p.setdefault("unsigned", {})
-                        unsigned["age"] = now - int(p["age_ts"])
-                        del p["age_ts"]
-            return data
-
-        try:
-            response = yield self._transport_layer.send_transaction(
-                transaction, json_data_cb
+        # Make a transaction-sending opentracing span. This span follows on from
+        # all the edus in that transaction. This needs to be done since there is
+        # no active span here, so if the edus were not received by the remote the
+        # span would have no causality and it would be forgotten.
+        # The span_contexts is a generator so that it won't be evaluated if
+        # opentracing is disabled. (Yay speed!)
+
+        span_contexts = []
+        keep_destination = whitelisted_homeserver(destination)
+
+        for edu in pending_edus:
+            context = edu.get_context()
+            if context:
+                span_contexts.append(extract_text_map(json.loads(context)))
+            if keep_destination:
+                edu.strip_context()
+
+        with start_active_span_follows_from("send_transaction", span_contexts):
+
+            # Sort based on the order field
+            pending_pdus.sort(key=lambda t: t[1])
+            pdus = [x[0] for x in pending_pdus]
+            edus = pending_edus
+
+            success = True
+
+            logger.debug("TX [%s] _attempt_new_transaction", destination)
+
+            txn_id = str(self._next_txn_id)
+
+            logger.debug(
+                "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
+                destination,
+                txn_id,
+                len(pdus),
+                len(edus),
+            )
+
+            transaction = Transaction.create_new(
+                origin_server_ts=int(self.clock.time_msec()),
+                transaction_id=txn_id,
+                origin=self._server_name,
+                destination=destination,
+                pdus=pdus,
+                edus=edus,
             )
-            code = 200
-        except HttpResponseException as e:
-            code = e.code
-            response = e.response
 
-            if e.code in (401, 404, 429) or 500 <= e.code:
-                logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-                raise e
+            self._next_txn_id += 1
 
-        logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+            logger.info(
+                "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
+                destination,
+                txn_id,
+                transaction.transaction_id,
+                len(pdus),
+                len(edus),
+            )
 
-        if code == 200:
-            for e_id, r in response.get("pdus", {}).items():
-                if "error" in r:
+            # Actually send the transaction
+
+            # FIXME (erikj): This is a bit of a hack to make the Pdu age
+            # keys work
+            def json_data_cb():
+                data = transaction.get_dict()
+                now = int(self.clock.time_msec())
+                if "pdus" in data:
+                    for p in data["pdus"]:
+                        if "age_ts" in p:
+                            unsigned = p.setdefault("unsigned", {})
+                            unsigned["age"] = now - int(p["age_ts"])
+                            del p["age_ts"]
+                return data
+
+            try:
+                response = yield self._transport_layer.send_transaction(
+                    transaction, json_data_cb
+                )
+                code = 200
+            except HttpResponseException as e:
+                code = e.code
+                response = e.response
+
+                if e.code in (401, 404, 429) or 500 <= e.code:
+                    logger.info(
+                        "TX [%s] {%s} got %d response", destination, txn_id, code
+                    )
+                    raise e
+
+            logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+
+            if code == 200:
+                for e_id, r in response.get("pdus", {}).items():
+                    if "error" in r:
+                        logger.warn(
+                            "TX [%s] {%s} Remote returned error for %s: %s",
+                            destination,
+                            txn_id,
+                            e_id,
+                            r,
+                        )
+            else:
+                for p in pdus:
                     logger.warn(
-                        "TX [%s] {%s} Remote returned error for %s: %s",
+                        "TX [%s] {%s} Failed to send event %s",
                         destination,
                         txn_id,
-                        e_id,
-                        r,
+                        p.event_id,
                     )
-        else:
-            for p in pdus:
-                logger.warn(
-                    "TX [%s] {%s} Failed to send event %s",
-                    destination,
-                    txn_id,
-                    p.event_id,
-                )
-            success = False
+                success = False
 
-        return success
+            set_tag(tags.ERROR, not success)
+            return success
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 0cea0d2a10..482a101c09 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -327,21 +327,37 @@ class TransportLayerClient(object):
         include_all_networks=False,
         third_party_instance_id=None,
     ):
-        path = _create_v1_path("/publicRooms")
-
-        args = {"include_all_networks": "true" if include_all_networks else "false"}
-        if third_party_instance_id:
-            args["third_party_instance_id"] = (third_party_instance_id,)
-        if limit:
-            args["limit"] = [str(limit)]
-        if since_token:
-            args["since"] = [since_token]
-
-        # TODO(erikj): Actually send the search_filter across federation.
-
-        response = yield self.client.get_json(
-            destination=remote_server, path=path, args=args, ignore_backoff=True
-        )
+        if search_filter:
+            # this uses MSC2197 (Search Filtering over Federation)
+            path = _create_v1_path("/publicRooms")
+
+            data = {"include_all_networks": "true" if include_all_networks else "false"}
+            if third_party_instance_id:
+                data["third_party_instance_id"] = third_party_instance_id
+            if limit:
+                data["limit"] = str(limit)
+            if since_token:
+                data["since"] = since_token
+
+            data["filter"] = search_filter
+
+            response = yield self.client.post_json(
+                destination=remote_server, path=path, data=data, ignore_backoff=True
+            )
+        else:
+            path = _create_v1_path("/publicRooms")
+
+            args = {"include_all_networks": "true" if include_all_networks else "false"}
+            if third_party_instance_id:
+                args["third_party_instance_id"] = (third_party_instance_id,)
+            if limit:
+                args["limit"] = [str(limit)]
+            if since_token:
+                args["since"] = [since_token]
+
+            response = yield self.client.get_json(
+                destination=remote_server, path=path, args=args, ignore_backoff=True
+            )
 
         return response
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 9a86bd0263..7f8a16e355 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -22,7 +22,6 @@ import re
 from twisted.internet.defer import maybeDeferred
 
 import synapse
-import synapse.logging.opentracing as opentracing
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.api.urls import (
@@ -39,6 +38,12 @@ from synapse.http.servlet import (
     parse_string_from_args,
 )
 from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+    start_active_span,
+    start_active_span_from_request,
+    tags,
+    whitelisted_homeserver,
+)
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
@@ -160,7 +165,7 @@ class Authenticator(object):
     async def _reset_retry_timings(self, origin):
         try:
             logger.info("Marking origin %r as up", origin)
-            await self.store.set_destination_retry_timings(origin, 0, 0)
+            await self.store.set_destination_retry_timings(origin, None, 0, 0)
         except Exception:
             logger.exception("Error resetting retry timings on %s", origin)
 
@@ -288,19 +293,28 @@ class BaseFederationServlet(object):
                 logger.warn("authenticate_request failed: %s", e)
                 raise
 
-            # Start an opentracing span
-            with opentracing.start_active_span_from_context(
-                request.requestHeaders,
-                "incoming-federation-request",
-                tags={
-                    "request_id": request.get_request_id(),
-                    opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_SERVER,
-                    opentracing.tags.HTTP_METHOD: request.get_method(),
-                    opentracing.tags.HTTP_URL: request.get_redacted_uri(),
-                    opentracing.tags.PEER_HOST_IPV6: request.getClientIP(),
-                    "authenticated_entity": origin,
-                },
-            ):
+            request_tags = {
+                "request_id": request.get_request_id(),
+                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+                tags.HTTP_METHOD: request.get_method(),
+                tags.HTTP_URL: request.get_redacted_uri(),
+                tags.PEER_HOST_IPV6: request.getClientIP(),
+                "authenticated_entity": origin,
+                "servlet_name": request.request_metrics.name,
+            }
+
+            # Only accept the span context if the origin is authenticated
+            # and whitelisted
+            if origin and whitelisted_homeserver(origin):
+                scope = start_active_span_from_request(
+                    request, "incoming-federation-request", tags=request_tags
+                )
+            else:
+                scope = start_active_span(
+                    "incoming-federation-request", tags=request_tags
+                )
+
+            with scope:
                 if origin:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
@@ -328,7 +342,11 @@ class BaseFederationServlet(object):
                 continue
 
             server.register_paths(
-                method, (pattern,), self._wrap(code), self.__class__.__name__
+                method,
+                (pattern,),
+                self._wrap(code),
+                self.__class__.__name__,
+                trace=False,
             )
 
 
@@ -557,7 +575,7 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 
     async def on_PUT(self, origin, content, query, room_id):
         content = await self.handler.on_exchange_third_party_invite_request(
-            origin, room_id, content
+            room_id, content
         )
         return 200, content
 
@@ -756,6 +774,42 @@ class PublicRoomList(BaseFederationServlet):
         )
         return 200, data
 
+    async def on_POST(self, origin, content, query):
+        # This implements MSC2197 (Search Filtering over Federation)
+        if not self.allow_access:
+            raise FederationDeniedError(origin)
+
+        limit = int(content.get("limit", 100))
+        since_token = content.get("since", None)
+        search_filter = content.get("filter", None)
+
+        include_all_networks = content.get("include_all_networks", False)
+        third_party_instance_id = content.get("third_party_instance_id", None)
+
+        if include_all_networks:
+            network_tuple = None
+            if third_party_instance_id is not None:
+                raise SynapseError(
+                    400, "Can't use include_all_networks with an explicit network"
+                )
+        elif third_party_instance_id is None:
+            network_tuple = ThirdPartyInstanceID(None, None)
+        else:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
+        if search_filter is None:
+            logger.warning("Nonefilter")
+
+        data = await self.handler.get_local_public_room_list(
+            limit=limit,
+            since_token=since_token,
+            search_filter=search_filter,
+            network_tuple=network_tuple,
+            from_federation=True,
+        )
+
+        return 200, data
+
 
 class FederationVersionServlet(BaseFederationServlet):
     PATH = "/version"
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 14aad8f09d..b4d743cde7 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -38,6 +38,12 @@ class Edu(JsonEncodedObject):
 
     internal_keys = ["origin", "destination"]
 
+    def get_context(self):
+        return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
+
+    def strip_context(self):
+        getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
+
 
 class Transaction(JsonEncodedObject):
     """ A transaction is a list of Pdus and Edus to be sent to a remote home