summary refs log tree commit diff
path: root/tests/rest/client/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/utils.py')
-rw-r--r--tests/rest/client/utils.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index a0788b1bb0..93f749744d 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.api.errors import Codes
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 
@@ -171,6 +172,8 @@ class RestHelper:
         expect_code: int = HTTPStatus.OK,
         tok: Optional[str] = None,
         appservice_user_id: Optional[str] = None,
+        expect_errcode: Optional[Codes] = None,
+        expect_additional_fields: Optional[dict] = None,
     ) -> None:
         self.change_membership(
             room=room,
@@ -180,6 +183,8 @@ class RestHelper:
             appservice_user_id=appservice_user_id,
             membership=Membership.JOIN,
             expect_code=expect_code,
+            expect_errcode=expect_errcode,
+            expect_additional_fields=expect_additional_fields,
         )
 
     def knock(
@@ -263,6 +268,7 @@ class RestHelper:
         appservice_user_id: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
         expect_errcode: Optional[str] = None,
+        expect_additional_fields: Optional[dict] = None,
     ) -> None:
         """
         Send a membership state event into a room.
@@ -323,6 +329,21 @@ class RestHelper:
                 channel.result["body"],
             )
 
+        if expect_additional_fields is not None:
+            for expect_key, expect_value in expect_additional_fields.items():
+                assert expect_key in channel.json_body, "Expected field %s, got %s" % (
+                    expect_key,
+                    channel.json_body,
+                )
+                assert (
+                    channel.json_body[expect_key] == expect_value
+                ), "Expected: %s at %s, got: %s, resp: %s" % (
+                    expect_value,
+                    expect_key,
+                    channel.json_body[expect_key],
+                    channel.json_body,
+                )
+
         self.auth_user_id = temp_id
 
     def send(