summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/v1/utils.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 69798e95c3..fc2d35596e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -19,7 +19,7 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Mapping, MutableMapping, Optional
+from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
 from unittest.mock import patch
 
 import attr
@@ -53,6 +53,9 @@ class RestHelper:
         tok: str = None,
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
     ) -> str:
         """
         Create a room.
@@ -87,6 +90,7 @@ class RestHelper:
             "POST",
             path,
             json.dumps(content).encode("utf8"),
+            custom_headers=custom_headers,
         )
 
         assert channel.result["code"] == b"%d" % expect_code, channel.result
@@ -175,14 +179,30 @@ class RestHelper:
 
         self.auth_user_id = temp_id
 
-    def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+    def send(
+        self,
+        room_id,
+        body=None,
+        txn_id=None,
+        tok=None,
+        expect_code=200,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+    ):
         if body is None:
             body = "body_text_here"
 
         content = {"msgtype": "m.text", "body": body}
 
         return self.send_event(
-            room_id, "m.room.message", content, txn_id, tok, expect_code
+            room_id,
+            "m.room.message",
+            content,
+            txn_id,
+            tok,
+            expect_code,
+            custom_headers=custom_headers,
         )
 
     def send_event(
@@ -193,6 +213,9 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
     ):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -207,6 +230,7 @@ class RestHelper:
             "PUT",
             path,
             json.dumps(content or {}).encode("utf8"),
+            custom_headers=custom_headers,
         )
 
         assert (