diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3c79d4afe7..4fba462d44 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -24,7 +24,12 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+ KeyLookupError,
+ PerspectivesKeyFetcher,
+ ServerKeyFetcher,
+)
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -80,7 +85,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
@@ -89,7 +94,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
@@ -192,8 +197,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_key(
- "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
+ key1_id = "%s:%s" % (key1.alg, key1.version)
+
+ r = self.hs.datastore.store_server_verify_keys(
+ "server9",
+ time.time() * 1000,
+ [
+ (
+ "server9",
+ key1_id,
+ FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
+ ),
+ ],
)
self.get_success(r)
json1 = {}
@@ -207,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called)
self.get_success(d)
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ return hs
+
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
+ fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -239,11 +261,12 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+ keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -266,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
- kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+ fetcher.get_keys(server_name_and_key_ids), KeyLookupError
)
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.mock_perspective_server = MockPerspectiveServer()
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ keys = self.mock_perspective_server.get_verify_keys()
+ hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ return hs
+
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -308,12 +342,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+ keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -336,7 +371,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx"):
+ with LoggingContext("testctx") as ctx:
+ # we set the "request" prop to make it easier to follow what's going on in the
+ # logs.
+ ctx.request = "testctx"
rv = yield f(*args, **kwargs)
defer.returnValue(rv)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1c253d0579..5ffba2ca7a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -228,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_not_support_user(self):
res = self.get_success(self.handler.register(localpart='user'))
self.assertFalse(self.store.is_support_user(res[0]))
+
+ def test_invalid_user_id_length(self):
+ invalid_user_id = "x" * 256
+ self.get_failure(
+ self.handler.register(localpart=invalid_user_id),
+ SynapseError
+ )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
new file mode 100644
index 0000000000..249aba3d59
--- /dev/null
+++ b/tests/handlers/test_stats.py
@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class StatsRoomTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+
+ self.store = hs.get_datastore()
+ self.handler = self.hs.get_stats_handler()
+
+ def _add_background_updates(self):
+ """
+ Add the background updates we need to run.
+ """
+ # Ugh, have to reset this flag
+ self.store._all_done = False
+
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_createtables",
+ },
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+
+ def test_initial_room(self):
+ """
+ The background updates will build the table from scratch.
+ """
+ r = self.get_success(self.store.get_all_room_state())
+ self.assertEqual(len(r), 0)
+
+ # Disable stats
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Stats disabled, shouldn't have done anything
+ r = self.get_success(self.store.get_all_room_state())
+ self.assertEqual(len(r), 0)
+
+ # Enable stats
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+
+ # Do the initial population of the user directory via the background update
+ self._add_background_updates()
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ r = self.get_success(self.store.get_all_room_state())
+
+ self.assertEqual(len(r), 1)
+ self.assertEqual(r[0]["topic"], "foo")
+
+ def test_initial_earliest_token(self):
+ """
+ Ingestion via notify_new_event will ignore tokens that the background
+ update have already processed.
+ """
+ self.reactor.advance(86401)
+
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ u2 = self.register_user("u2", "pass")
+ u2_token = self.login("u2", "pass")
+
+ u3 = self.register_user("u3", "pass")
+ u3_token = self.login("u3", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Begin the ingestion by creating the temp tables. This will also store
+ # the position that the deltas should begin at, once they take over.
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+ self.store._all_done = False
+ self.get_success(self.store.update_stats_stream_pos(None))
+
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ )
+ )
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # Now, before the table is actually ingested, add some more events.
+ self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room=room_1, user=u2, tok=u2_token)
+
+ # Now do the initial ingestion.
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+
+ self.store._all_done = False
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ self.reactor.advance(86401)
+
+ # Now add some more events, triggering ingestion. Because of the stream
+ # position being set to before the events sent in the middle, a simpler
+ # implementation would reprocess those events, and say there were four
+ # users, not three.
+ self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
+ self.helper.join(room=room_1, user=u3, tok=u3_token)
+
+ # Get the deltas! There should be two -- day 1, and day 2.
+ r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+
+ # The oldest has 2 joined members
+ self.assertEqual(r[-1]["joined_members"], 2)
+
+ # The newest has 3
+ self.assertEqual(r[0]["joined_members"], 3)
+
+ def test_incorrect_state_transition(self):
+ """
+ If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
+ (JOIN, INVITE, LEAVE, BAN), an error is raised.
+ """
+ events = {
+ "a1": {"membership": Membership.LEAVE},
+ "a2": {"membership": "not a real thing"},
+ }
+
+ def get_event(event_id):
+ m = Mock()
+ m.content = events[event_id]
+ d = defer.Deferred()
+ self.reactor.callLater(0.0, d.callback, m)
+ return d
+
+ def get_received_ts(event_id):
+ return defer.succeed(1)
+
+ self.store.get_received_ts = get_received_ts
+ self.store.get_event = get_event
+
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user",
+ "room_id": "room",
+ "event_id": "a1",
+ "prev_event_id": "a2",
+ "stream_id": "bleb",
+ }
+ ]
+
+ f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ self.assertEqual(
+ f.value.args[0], "'not a real thing' is not a valid prev_membership"
+ )
+
+ # And the other way...
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user",
+ "room_id": "room",
+ "event_id": "a2",
+ "prev_event_id": "a1",
+ "stream_id": "bleb",
+ }
+ ]
+
+ f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ self.assertEqual(
+ f.value.args[0], "'not a real thing' is not a valid membership"
+ )
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 05b0143c42..f7133fc12e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -127,3 +127,20 @@ class RestHelper(object):
)
return channel.json_body
+
+ def send_state(self, room_id, event_type, body, tok, expect_code=200):
+ path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8')
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index ad7d476401..b9ef46e8fb 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -92,7 +92,14 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.recaptcha_attempts), 1)
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
- # Now we have fufilled the recaptcha fallback step, we can then send a
+ # also complete the dummy auth
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ )
+ self.render(request)
+
+ # Now we should have fufilled a complete auth flow, including
+ # the recaptcha fallback step, we can then send a
# request to the register API with the session in the authdict.
request, channel = self.make_request(
"POST", "register", {"auth": {"session": session}}
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index f3ef977404..bce5b0cf4c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.rest.admin
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
@@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
+ self.config = hs.config
return hs
def test_check_auth_required(self):
@@ -51,8 +52,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
for room_version in capabilities['m.room_versions']['available'].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+
self.assertEqual(
- DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
+ self.config.default_room_version.identifier,
+ capabilities['m.room_versions']['default'],
)
def test_get_change_password_capabilities(self):
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
new file mode 100644
index 0000000000..43b3049daa
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -0,0 +1,564 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import json
+
+import six
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import register, relations
+
+from tests import unittest
+
+
+class RelationsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ relations.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ ]
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ # We need to enable msc1849 support for aggregations
+ config = self.default_config()
+ config["experimental_msc1849_support_enabled"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id, self.user_token = self._create_user("alice")
+ self.user2_id, self.user2_token = self._create_user("bob")
+
+ self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
+ res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
+ self.parent_id = res["event_id"]
+
+ def test_send_relation(self):
+ """Tests that sending a relation using the new /send_relation works
+ creates the right shape of event.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assert_dict(
+ {
+ "type": "m.reaction",
+ "sender": self.user_id,
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": u"👍",
+ "rel_type": RelationTypes.ANNOTATION,
+ }
+ },
+ },
+ channel.json_body,
+ )
+
+ def test_deny_membership(self):
+ """Test that we deny relations on membership events
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_deny_double_react(self):
+ """Test that we deny relations on membership events
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_basic_paginate_relations(self):
+ """Tests that calling pagination API corectly the latest relations.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
+ channel.json_body["chunk"][0],
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), six.string_types, channel.json_body
+ )
+
+ def test_repeated_paginate_relations(self):
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = None
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
+ % (self.room, self.parent_id, from_token),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_aggregation_pagination_groups(self):
+ """Test that we can paginate annotation groups correctly.
+ """
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token = None
+ found_groups = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
+ % (self.room, self.parent_id, from_token),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEquals(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self):
+ """Test that we can paginate within an annotation group.
+ """
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=u"👍",
+ access_token=access_tokens[idx],
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ idx += 1
+
+ # Also send a different type of reaction so that we test we don't see it
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ prev_token = None
+ found_event_ids = []
+ encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8"))
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s"
+ "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
+ % (
+ self.room,
+ self.parent_id,
+ RelationTypes.ANNOTATION,
+ encoded_key,
+ from_token,
+ ),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_aggregation(self):
+ """Test that annotations get correctly aggregated.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body,
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ )
+
+ def test_aggregation_redactions(self):
+ """Test that annotations get correctly aggregated after a redaction.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+ to_redact_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Now lets redact one of the 'a' reactions
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
+ access_token=self.user_token,
+ content={},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body,
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ )
+
+ def test_aggregation_must_be_annotation(self):
+ """Test that aggregations must be annotations.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
+ % (self.room, self.parent_id, RelationTypes.REPLACE),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_aggregation_get_event(self):
+ """Test that annotations and references get correctly bundled when
+ getting the parent event.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply_2 = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
+ },
+ },
+ )
+
+ def test_edit(self):
+ """Test that a simple edit works.
+ """
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ )
+
+ def test_multi_edit(self):
+ """Test that multiple edits, including attempts by people who
+ shouldn't be allowed, are correctly handled.
+ """
+
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message.WRONG_TYPE",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ )
+
+ def _send_relation(
+ self, relation_type, event_type, key=None, content={}, access_token=None
+ ):
+ """Helper function to send a relation pointing at `self.parent_id`
+
+ Args:
+ relation_type (str): One of `RelationTypes`
+ event_type (str): The type of the event to create
+ key (str|None): The aggregation key used for m.annotation relation
+ type.
+ content(dict|None): The content of the created event.
+ access_token (str|None): The access token used to send the relation,
+ defaults to `self.user_token`
+
+ Returns:
+ FakeChannel
+ """
+ if not access_token:
+ access_token = self.user_token
+
+ query = ""
+ if key:
+ query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
+ % (self.room, self.parent_id, relation_type, event_type, query),
+ json.dumps(content).encode("utf-8"),
+ access_token=access_token,
+ )
+ self.render(request)
+ return channel
+
+ def _create_user(self, localpart):
+ user_id = self.register_user(localpart, "abc123")
+ access_token = self.login(localpart, "abc123")
+
+ return user_id, access_token
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 6bfaa00fe9..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
from twisted.internet.defer import Deferred
+from synapse.storage.keys import FetchKeyResult
+
import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()
- d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
- self.get_success(d)
- d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+ key_id_1 = "ed25519:key1"
+ key_id_2 = "ed25519:KEY_ID_2"
+ d = store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
self.get_success(d)
d = store.get_server_verify_keys(
- [
- ("server1", "ed25519:key1"),
- ("server1", "ed25519:key2"),
- ("server1", "ed25519:key3"),
- ]
+ [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
)
res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
- self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
- self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+ res1 = res[("server1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.verify_key.version, "key1")
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("server1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ # version comes from the ID it was stored with
+ self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+ self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
- self.get_success(d)
- d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+ d = store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+ self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
- d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+ d = store.store_server_verify_keys(
+ "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
+ )
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, new_key_2)
+ self.assertEqual(res2.valid_until_ts, 300)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index f412985d2c..52739fbabc 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -59,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
for flow in channel.json_body["flows"]:
self.assertIsInstance(flow["stages"], list)
self.assertTrue(len(flow["stages"]) > 0)
- self.assertEquals(flow["stages"][-1], "m.login.terms")
+ self.assertTrue("m.login.terms" in flow["stages"])
expected_params = {
"m.login.terms": {
|