diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 4c30c5f258..99b716f00a 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,9 +15,11 @@
from mock import Mock
+from synapse.rest.client.v1 import admin, login, room
from synapse.rest.client.v2_alpha import sync
from tests import unittest
+from tests.server import TimedOutException
class FilterTestCase(unittest.HomeserverTestCase):
@@ -65,3 +67,124 @@ class FilterTestCase(unittest.HomeserverTestCase):
["next_batch", "rooms", "account_data", "to_device", "device_lists"]
).issubset(set(channel.json_body.keys()))
)
+
+
+class SyncTypingTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+ user_id = True
+ hijack_auth = False
+
+ def test_sync_backwards_typing(self):
+ """
+ If the typing serial goes backwards and the typing handler is then reset
+ (such as when the master restarts and sets the typing serial to 0), we
+ do not incorrectly return typing information that had a serial greater
+ than the now-reset serial.
+ """
+ typing_url = "/rooms/%s/typing/%s?access_token=%s"
+ sync_url = "/sync?timeout=3000000&access_token=%s&since=%s"
+
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # Invite the other person
+ self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ request, channel = self.make_request(
+ "GET", "/sync?access_token=%s" % (access_token,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Stop typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": false}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Start typing.
+ request, channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Should return immediately
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Reset typing serial back to 0, as if the master had.
+ typing = self.hs.get_typing_handler()
+ typing._latest_room_serial = 0
+
+ # Since it checks the state token, we need some state to update to
+ # invalidate the stream token.
+ self.helper.send(room, body="There!", tok=other_access_token)
+
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # This should time out! But it does not, because our stream token is
+ # ahead, and therefore it's saying the typing (that we've actually
+ # already seen) is new, since it's got a token above our new, now-reset
+ # stream token.
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ next_batch = channel.json_body["next_batch"]
+
+ # Clear the typing information, so that it doesn't think everything is
+ # in the future.
+ typing._reset()
+
+ # Now it SHOULD fail as it never completes!
+ request, channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch)
+ )
+ self.assertRaises(TimedOutException, self.render, request)
diff --git a/tests/server.py b/tests/server.py
index 819c854448..cc6dbe04ac 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -21,6 +21,12 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+class TimedOutException(Exception):
+ """
+ A web query timed out.
+ """
+
+
@attr.s
class FakeChannel(object):
"""
@@ -153,7 +159,7 @@ def wait_until_result(clock, request, timeout=100):
x += 1
if x > timeout:
- raise Exception("Timed out waiting for request to finish.")
+ raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index efd85ebe6c..d67f59b2c7 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -544,8 +544,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- self.assertTrue(state_d.called)
- state_before = state_d.result
+ state_before = self.successResultOf(state_d)
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -599,6 +598,103 @@ class LexicographicalTestCase(unittest.TestCase):
self.assertEqual(["o", "l", "n", "m", "p"], res)
+class SimpleParamStateTestCase(unittest.TestCase):
+ def setUp(self):
+ # We build up a simple DAG.
+
+ event_map = {}
+
+ create_event = FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ).to_event([], [])
+ event_map[create_event.event_id] = create_event
+
+ alice_member = FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event([create_event.event_id], [create_event.event_id])
+ event_map[alice_member.event_id] = alice_member
+
+ join_rules = FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ).to_event(
+ auth_events=[create_event.event_id, alice_member.event_id],
+ prev_events=[alice_member.event_id],
+ )
+ event_map[join_rules.event_id] = join_rules
+
+ # Bob and Charlie join at the same time, so there is a fork
+ bob_member = FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event(
+ auth_events=[create_event.event_id, join_rules.event_id],
+ prev_events=[join_rules.event_id],
+ )
+ event_map[bob_member.event_id] = bob_member
+
+ charlie_member = FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ).to_event(
+ auth_events=[create_event.event_id, join_rules.event_id],
+ prev_events=[join_rules.event_id],
+ )
+ event_map[charlie_member.event_id] = charlie_member
+
+ self.event_map = event_map
+ self.create_event = create_event
+ self.alice_member = alice_member
+ self.join_rules = join_rules
+ self.bob_member = bob_member
+ self.charlie_member = charlie_member
+
+ self.state_at_bob = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, bob_member]
+ }
+
+ self.state_at_charlie = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, charlie_member]
+ }
+
+ self.expected_combined_state = {
+ (e.type, e.state_key): e.event_id
+ for e in [create_event, alice_member, join_rules, bob_member, charlie_member]
+ }
+
+ def test_event_map_none(self):
+ # Test that we correctly handle passing `None` as the event_map
+
+ state_d = resolve_events_with_store(
+ [self.state_at_bob, self.state_at_charlie],
+ event_map=None,
+ state_res_store=TestStateResolutionStore(self.event_map),
+ )
+
+ state = self.successResultOf(state_d)
+
+ self.assert_dict(self.expected_combined_state, state)
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
new file mode 100644
index 0000000000..7deab5266f
--- /dev/null
+++ b/tests/test_terms_auth.py
@@ -0,0 +1,123 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import six
+from mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactorClock
+
+from synapse.rest.client.v2_alpha.register import register_servlets
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import make_request
+
+
+class TermsTestCase(unittest.HomeserverTestCase):
+ servlets = [register_servlets]
+
+ def prepare(self, reactor, clock, hs):
+ self.clock = MemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+ self.url = "/_matrix/client/r0/register"
+ self.registration_handler = Mock()
+ self.auth_handler = Mock()
+ self.device_handler = Mock()
+ hs.config.enable_registration = True
+ hs.config.registrations_require_3pid = []
+ hs.config.auto_join_rooms = []
+ hs.config.enable_registration_captcha = False
+
+ def test_ui_auth(self):
+ self.hs.config.block_events_without_consent_error = True
+ self.hs.config.public_baseurl = "https://example.org"
+ self.hs.config.user_consent_version = "1.0"
+
+ # Do a UI auth request
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ self.assertTrue(channel.json_body is not None)
+ self.assertIsInstance(channel.json_body["session"], six.text_type)
+
+ self.assertIsInstance(channel.json_body["flows"], list)
+ for flow in channel.json_body["flows"]:
+ self.assertIsInstance(flow["stages"], list)
+ self.assertTrue(len(flow["stages"]) > 0)
+ self.assertEquals(flow["stages"][-1], "m.login.terms")
+
+ expected_params = {
+ "m.login.terms": {
+ "policies": {
+ "privacy_policy": {
+ "en": {
+ "name": "Privacy Policy",
+ "url": "https://example.org/_matrix/consent?v=1.0",
+ },
+ "version": "1.0"
+ },
+ },
+ },
+ }
+ self.assertIsInstance(channel.json_body["params"], dict)
+ self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+
+ # We have to complete the dummy auth stage before completing the terms stage
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.dummy",
+ },
+ }
+ )
+
+ self.registration_handler.check_username = Mock(return_value=True)
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ self.render(request)
+
+ # We don't bother checking that the response is correct - we'll leave that to
+ # other tests. We just want to make sure we're on the right path.
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ # Finish the UI auth for terms
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "auth": {
+ "session": channel.json_body["session"],
+ "type": "m.login.terms",
+ },
+ }
+ )
+ request, channel = make_request(b"POST", self.url, request_data)
+ self.render(request)
+
+ # We're interested in getting a response that looks like a successful
+ # registration, not so much that the details are exactly what we want.
+
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.assertTrue(channel.json_body is not None)
+ self.assertIsInstance(channel.json_body["user_id"], six.text_type)
+ self.assertIsInstance(channel.json_body["access_token"], six.text_type)
+ self.assertIsInstance(channel.json_body["device_id"], six.text_type)
|