summary refs log tree commit diff
path: root/tests/rest/client/v1
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v1')
-rw-r--r--tests/rest/client/v1/utils.py96
1 files changed, 92 insertions, 4 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 873d5ef99c..371637618d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -18,6 +18,7 @@
 
 import json
 import time
+from typing import Any, Dict, Optional
 
 import attr
 
@@ -142,7 +143,34 @@ class RestHelper(object):
 
         return channel.json_body
 
-    def send_state(self, room_id, event_type, body, tok, expect_code=200, state_key=""):
+    def _read_write_state(
+        self,
+        room_id: str,
+        event_type: str,
+        body: Optional[Dict[str, Any]],
+        tok: str,
+        expect_code: int = 200,
+        state_key: str = "",
+        method: str = "GET",
+    ) -> Dict:
+        """Read or write some state from a given room
+
+        Args:
+            room_id:
+            event_type: The type of state event
+            body: Body that is sent when making the request. The content of the state event.
+                If None, the request to the server will have an empty body
+            tok: The access token to use
+            expect_code: The HTTP code to expect in the response
+            state_key:
+            method: "GET" or "PUT" for reading or writing state, respectively
+
+        Returns:
+            The response body from the server
+
+        Raises:
+            AssertionError: if expect_code doesn't match the HTTP code we received
+        """
         path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
             room_id,
             event_type,
@@ -151,9 +179,13 @@ class RestHelper(object):
         if tok:
             path = path + "?access_token=%s" % tok
 
-        request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(body).encode("utf8")
-        )
+        # Set request body if provided
+        content = b""
+        if body is not None:
+            content = json.dumps(body).encode("utf8")
+
+        request, channel = make_request(self.hs.get_reactor(), method, path, content)
+
         render(request, self.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
@@ -163,6 +195,62 @@ class RestHelper(object):
 
         return channel.json_body
 
+    def get_state(
+        self,
+        room_id: str,
+        event_type: str,
+        tok: str,
+        expect_code: int = 200,
+        state_key: str = "",
+    ):
+        """Gets some state from a room
+
+        Args:
+            room_id:
+            event_type: The type of state event
+            tok: The access token to use
+            expect_code: The HTTP code to expect in the response
+            state_key:
+
+        Returns:
+            The response body from the server
+
+        Raises:
+            AssertionError: if expect_code doesn't match the HTTP code we received
+        """
+        return self._read_write_state(
+            room_id, event_type, None, tok, expect_code, state_key, method="GET"
+        )
+
+    def send_state(
+        self,
+        room_id: str,
+        event_type: str,
+        body: Dict[str, Any],
+        tok: str,
+        expect_code: int = 200,
+        state_key: str = "",
+    ):
+        """Set some state in a room
+
+        Args:
+            room_id:
+            event_type: The type of state event
+            body: Body that is sent when making the request. The content of the state event.
+            tok: The access token to use
+            expect_code: The HTTP code to expect in the response
+            state_key:
+
+        Returns:
+            The response body from the server
+
+        Raises:
+            AssertionError: if expect_code doesn't match the HTTP code we received
+        """
+        return self._read_write_state(
+            room_id, event_type, body, tok, expect_code, state_key, method="PUT"
+        )
+
     def upload_media(
         self,
         resource: Resource,