diff options
-rw-r--r-- | tests/rest/client/v2_alpha/test_relations.py | 80 |
1 files changed, 68 insertions, 12 deletions
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index b0e4c47ae3..cd965167f8 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -19,19 +19,22 @@ import json import six from synapse.api.constants import EventTypes, RelationTypes +from synapse.rest import admin from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import relations +from synapse.rest.client.v2_alpha import register, relations from tests import unittest class RelationsTestCase(unittest.HomeserverTestCase): - user_id = "@alice:test" servlets = [ relations.register_servlets, room.register_servlets, login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, ] + hijack_auth = False def make_homeserver(self, reactor, clock): # We need to enable msc1849 support for aggregations @@ -40,8 +43,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): - self.room = self.helper.create_room_as(self.user_id) - res = self.helper.send(self.room, body="Hi!") + self.user_id, self.user_token = self._create_user("alice") + self.user2_id, self.user2_token = self._create_user("bob") + + self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) + self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) + res = self.helper.send(self.room, body="Hi!", tok=self.user_token) self.parent_id = res["event_id"] def test_send_relation(self): @@ -55,7 +62,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): event_id = channel.json_body["event_id"] request, channel = self.make_request( - "GET", "/rooms/%s/event/%s" % (self.room, event_id) + "GET", + "/rooms/%s/event/%s" % (self.room, event_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -95,6 +104,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" % (self.room, self.parent_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -135,6 +145,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" % (self.room, self.parent_id, from_token), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -156,15 +167,32 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Test that we can paginate annotation groups correctly. """ + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} for key in itertools.chain.from_iterable( itertools.repeat(key, num) for key, num in sent_groups.items() ): channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key=key + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], ) self.assertEquals(200, channel.code, channel.json_body) + idx += 1 + idx %= len(access_tokens) + prev_token = None found_groups = {} for _ in range(20): @@ -176,6 +204,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" % (self.room, self.parent_id, from_token), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -236,6 +265,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): encoded_key, from_token, ), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -263,7 +293,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) self.assertEquals(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") @@ -273,6 +305,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s" % (self.room, self.parent_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -295,6 +328,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1" % (self.room, self.parent_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(400, channel.code, channel.json_body) @@ -307,7 +341,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) self.assertEquals(200, channel.code, channel.json_body) channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") @@ -322,7 +358,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): reply_2 = channel.json_body["event_id"] request, channel = self.make_request( - "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -357,7 +395,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] request, channel = self.make_request( - "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -407,7 +447,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) request, channel = self.make_request( - "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id) + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, ) self.render(request) self.assertEquals(200, channel.code, channel.json_body) @@ -419,7 +461,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): {RelationTypes.REPLACES: {"event_id": edit_event_id}}, ) - def _send_relation(self, relation_type, event_type, key=None, content={}): + def _send_relation( + self, relation_type, event_type, key=None, content={}, access_token=None + ): """Helper function to send a relation pointing at `self.parent_id` Args: @@ -428,10 +472,15 @@ class RelationsTestCase(unittest.HomeserverTestCase): key (str|None): The aggregation key used for m.annotation relation type. content(dict|None): The content of the created event. + access_token (str|None): The access token used to send the relation, + defaults to `self.user_token` Returns: FakeChannel """ + if not access_token: + access_token = self.user_token + query = "" if key: query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) @@ -441,6 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, self.parent_id, relation_type, event_type, query), json.dumps(content).encode("utf-8"), + access_token=access_token, ) self.render(request) return channel + + def _create_user(self, localpart): + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token |