summary refs log tree commit diff
diff options
context:
space:
mode:
authorMarcel <MTRNord@users.noreply.github.com>2023-09-06 21:19:17 +0200
committerGitHub <noreply@github.com>2023-09-06 15:19:17 -0400
commit13e9cad537a16108b0cb544ccdc24e7dc2ca33ae (patch)
treea451a244a2cf5f5ac6dfa9e618bdc8c8431ddb2f
parentHandle "registration_enabled" parameter for CAS (#16262) (diff)
downloadsynapse-13e9cad537a16108b0cb544ccdc24e7dc2ca33ae.tar.xz
Send the opentracing span information to appservices (#16227)
-rw-r--r--changelog.d/16227.feature1
-rw-r--r--synapse/appservice/api.py32
-rw-r--r--tests/appservice/test_api.py18
3 files changed, 37 insertions, 14 deletions
diff --git a/changelog.d/16227.feature b/changelog.d/16227.feature
new file mode 100644
index 0000000000..510062b622
--- /dev/null
+++ b/changelog.d/16227.feature
@@ -0,0 +1 @@
+Add span information to requests sent to appservices. Contributed by MTRNord.
\ No newline at end of file
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index de7a94bf26..b1523be208 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -40,6 +40,7 @@ from synapse.appservice import (
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig, serialize_event
 from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
+from synapse.logging import opentracing
 from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -125,6 +126,17 @@ class ApplicationServiceApi(SimpleHttpClient):
             hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )
 
+    def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
+        """This makes sure we have always the auth header and opentracing headers set."""
+
+        # This is also ensured before in the functions. However this is needed to please
+        # the typechecks.
+        assert service.hs_token is not None
+
+        headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]}
+        opentracing.inject_header_dict(headers, check_destination=False)
+        return headers
+
     async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
         if service.url is None:
             return False
@@ -136,10 +148,11 @@ class ApplicationServiceApi(SimpleHttpClient):
             args = None
             if self.config.use_appservice_legacy_authorization:
                 args = {"access_token": service.hs_token}
+
             response = await self.get_json(
                 f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}",
                 args,
-                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                headers=self._get_headers(service),
             )
             if response is not None:  # just an empty json object
                 return True
@@ -162,10 +175,11 @@ class ApplicationServiceApi(SimpleHttpClient):
             args = None
             if self.config.use_appservice_legacy_authorization:
                 args = {"access_token": service.hs_token}
+
             response = await self.get_json(
                 f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}",
                 args,
-                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                headers=self._get_headers(service),
             )
             if response is not None:  # just an empty json object
                 return True
@@ -203,10 +217,11 @@ class ApplicationServiceApi(SimpleHttpClient):
                     **fields,
                     b"access_token": service.hs_token,
                 }
+
             response = await self.get_json(
                 f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
                 args=args,
-                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                headers=self._get_headers(service),
             )
             if not isinstance(response, list):
                 logger.warning(
@@ -243,10 +258,11 @@ class ApplicationServiceApi(SimpleHttpClient):
                 args = None
                 if self.config.use_appservice_legacy_authorization:
                     args = {"access_token": service.hs_token}
+
                 info = await self.get_json(
                     f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}",
                     args,
-                    headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                    headers=self._get_headers(service),
                 )
 
                 if not _is_valid_3pe_metadata(info):
@@ -283,7 +299,7 @@ class ApplicationServiceApi(SimpleHttpClient):
         await self.post_json_get_json(
             uri=f"{service.url}{APP_SERVICE_PREFIX}/ping",
             post_json={"transaction_id": txn_id},
-            headers={"Authorization": [f"Bearer {service.hs_token}"]},
+            headers=self._get_headers(service),
         )
 
     async def push_bulk(
@@ -364,7 +380,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}",
                 json_body=body,
                 args=args,
-                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                headers=self._get_headers(service),
             )
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
@@ -437,7 +453,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             response = await self.post_json_get_json(
                 uri,
                 body,
-                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                headers=self._get_headers(service),
             )
         except HttpResponseException as e:
             # The appservice doesn't support this endpoint.
@@ -498,7 +514,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             response = await self.post_json_get_json(
                 uri,
                 query,
-                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+                headers=self._get_headers(service),
             )
         except HttpResponseException as e:
             # The appservice doesn't support this endpoint.
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 75fb5fae6b..366b6fd5f0 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -76,7 +76,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
         ) -> List[JsonDict]:
             # Ensure the access token is passed as a header.
-            if not headers or not headers.get("Authorization"):
+            if not headers or not headers.get(b"Authorization"):
                 raise RuntimeError("Access token not provided")
             # ... and not as a query param
             if b"access_token" in args:
@@ -84,7 +84,9 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
                     "Access token should not be passed as a query param."
                 )
 
-            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+            self.assertEqual(
+                headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+            )
             self.request_url = url
             if url == URL_USER:
                 return SUCCESS_RESULT_USER
@@ -152,11 +154,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             # Ensure the access token is passed as a both a query param and in the headers.
             if not args.get(b"access_token"):
                 raise RuntimeError("Access token should be provided in query params.")
-            if not headers or not headers.get("Authorization"):
+            if not headers or not headers.get(b"Authorization"):
                 raise RuntimeError("Access token should be provided in auth headers.")
 
             self.assertEqual(args.get(b"access_token"), TOKEN)
-            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+            self.assertEqual(
+                headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+            )
             self.request_url = url
             if url == URL_USER:
                 return SUCCESS_RESULT_USER
@@ -208,10 +212,12 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
             headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
         ) -> JsonDict:
             # Ensure the access token is passed as both a header and query arg.
-            if not headers.get("Authorization"):
+            if not headers.get(b"Authorization"):
                 raise RuntimeError("Access token not provided")
 
-            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+            self.assertEqual(
+                headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
+            )
             return RESPONSE
 
         # We assign to a method, which mypy doesn't like.