diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
new file mode 100644
index 0000000000..fc37c4328c
--- /dev/null
+++ b/tests/handlers/test_admin.py
@@ -0,0 +1,210 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import Counter
+
+from mock import Mock
+
+import synapse.api.errors
+import synapse.handlers.admin
+import synapse.rest.admin
+import synapse.storage
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class ExfiltrateData(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ self.user1 = self.register_user("user1", "password")
+ self.token1 = self.login("user1", "password")
+
+ self.user2 = self.register_user("user2", "password")
+ self.token2 = self.login("user2", "password")
+
+ def test_single_public_joined_room(self):
+ """Test that we write *all* events for a public room
+ """
+ room_id = self.helper.create_room_as(
+ self.user1, tok=self.token1, is_public=True
+ )
+ self.helper.send(room_id, body="Hello!", tok=self.token1)
+ self.helper.join(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Hello again!", tok=self.token1)
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_called()
+
+ # Since we can see all events there shouldn't be any extremities, so no
+ # state should be written
+ writer.write_state.assert_not_called()
+
+ # Collect all events that were written
+ written_events = []
+ for (called_room_id, events), _ in writer.write_events.call_args_list:
+ self.assertEqual(called_room_id, room_id)
+ written_events.extend(events)
+
+ # Check that the right number of events were written
+ counter = Counter(
+ (event.type, getattr(event, "state_key", None)) for event in written_events
+ )
+ self.assertEqual(counter[(EventTypes.Message, None)], 2)
+ self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+ self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
+
+ def test_single_private_joined_room(self):
+ """Tests that we correctly write state when we can't see all events in
+ a room.
+ """
+ room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": "joined"},
+ tok=self.token1,
+ )
+ self.helper.send(room_id, body="Hello!", tok=self.token1)
+ self.helper.join(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Hello again!", tok=self.token1)
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_called()
+
+ # Since we can't see all events there should be one extremity.
+ writer.write_state.assert_called_once()
+
+ # Collect all events that were written
+ written_events = []
+ for (called_room_id, events), _ in writer.write_events.call_args_list:
+ self.assertEqual(called_room_id, room_id)
+ written_events.extend(events)
+
+ # Check that the right number of events were written
+ counter = Counter(
+ (event.type, getattr(event, "state_key", None)) for event in written_events
+ )
+ self.assertEqual(counter[(EventTypes.Message, None)], 1)
+ self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+ self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
+
+ def test_single_left_room(self):
+ """Tests that we don't see events in the room after we leave.
+ """
+ room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+ self.helper.send(room_id, body="Hello!", tok=self.token1)
+ self.helper.join(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Hello again!", tok=self.token1)
+ self.helper.leave(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Helloooooo!", tok=self.token1)
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_called()
+
+ # Since we can see all events there shouldn't be any extremities, so no
+ # state should be written
+ writer.write_state.assert_not_called()
+
+ written_events = []
+ for (called_room_id, events), _ in writer.write_events.call_args_list:
+ self.assertEqual(called_room_id, room_id)
+ written_events.extend(events)
+
+ # Check that the right number of events were written
+ counter = Counter(
+ (event.type, getattr(event, "state_key", None)) for event in written_events
+ )
+ self.assertEqual(counter[(EventTypes.Message, None)], 2)
+ self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+ self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
+
+ def test_single_left_rejoined_private_room(self):
+ """Tests that see the correct events in private rooms when we
+ repeatedly join and leave.
+ """
+ room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomHistoryVisibility,
+ body={"history_visibility": "joined"},
+ tok=self.token1,
+ )
+ self.helper.send(room_id, body="Hello!", tok=self.token1)
+ self.helper.join(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Hello again!", tok=self.token1)
+ self.helper.leave(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Helloooooo!", tok=self.token1)
+ self.helper.join(room_id, self.user2, tok=self.token2)
+ self.helper.send(room_id, body="Helloooooo!!", tok=self.token1)
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_called_once()
+
+ # Since we joined/left/joined again we expect there to be two gaps.
+ self.assertEqual(writer.write_state.call_count, 2)
+
+ written_events = []
+ for (called_room_id, events), _ in writer.write_events.call_args_list:
+ self.assertEqual(called_room_id, room_id)
+ written_events.extend(events)
+
+ # Check that the right number of events were written
+ counter = Counter(
+ (event.type, getattr(event, "state_key", None)) for event in written_events
+ )
+ self.assertEqual(counter[(EventTypes.Message, None)], 2)
+ self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
+ self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
+
+ def test_invite(self):
+ """Tests that pending invites get handled correctly.
+ """
+ room_id = self.helper.create_room_as(self.user1, tok=self.token1)
+ self.helper.send(room_id, body="Hello!", tok=self.token1)
+ self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_events.assert_not_called()
+ writer.write_state.assert_not_called()
+ writer.write_invite.assert_called_once()
+
+ args = writer.write_invite.call_args[0]
+ self.assertEqual(args[0], room_id)
+ self.assertEqual(args[1].content["membership"], "invite")
+ self.assertTrue(args[2]) # Assert there is at least one bit of state
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 1e39fe0ec2..b03103d96f 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -117,7 +117,9 @@ class AuthTestCase(unittest.TestCase):
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.auth_handler.get_access_token_for_user_id('user_a')
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
@@ -131,7 +133,9 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id('user_a')
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
@@ -150,7 +154,9 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id('user_a')
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
@@ -166,7 +172,9 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- yield self.auth_handler.get_access_token_for_user_id('user_a')
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
@@ -185,7 +193,9 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
- yield self.auth_handler.get_access_token_for_user_id('user_a')
+ yield self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index a3aa0a1cf2..62b47f6574 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -160,6 +160,24 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
+ def test_update_device_too_long_display_name(self):
+ """Update a device with a display name that is invalid (too long)."""
+ self._record_users()
+
+ # Request to update a device display name with a new value that is longer than allowed.
+ update = {
+ "display_name": "a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+ }
+ self.get_failure(
+ self.handler.update_device(user1, "abc", update),
+ synapse.api.errors.SynapseError,
+ )
+
+ # Ensure the display name was not updated.
+ res = self.get_success(self.handler.get_device(user1, "abc"))
+ self.assertEqual(res["display_name"], "display 2")
+
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
res = self.handler.update_device("user_id", "unknown_device_id", update)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 917548bb31..5e40adba52 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -18,25 +18,20 @@ from mock import Mock
from twisted.internet import defer
+import synapse
+import synapse.api.errors
+from synapse.api.constants import EventTypes
from synapse.config.room_directory import RoomDirectoryConfig
-from synapse.handlers.directory import DirectoryHandler
-from synapse.rest.client.v1 import directory, room
-from synapse.types import RoomAlias
+from synapse.rest.client.v1 import directory, login, room
+from synapse.types import RoomAlias, create_requester
from tests import unittest
-from tests.utils import setup_test_homeserver
-class DirectoryHandlers(object):
- def __init__(self, hs):
- self.directory_handler = DirectoryHandler(hs)
-
-
-class DirectoryTestCase(unittest.TestCase):
+class DirectoryTestCase(unittest.HomeserverTestCase):
""" Tests the directory service. """
- @defer.inlineCallbacks
- def setUp(self):
+ def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
self.mock_registry = Mock()
@@ -47,14 +42,12 @@ class DirectoryTestCase(unittest.TestCase):
self.mock_registry.register_query_handler = register_query_handler
- hs = yield setup_test_homeserver(
- self.addCleanup,
+ hs = self.setup_test_homeserver(
http_client=None,
resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
- hs.handlers = DirectoryHandlers(hs)
self.handler = hs.get_handlers().directory_handler
@@ -64,23 +57,25 @@ class DirectoryTestCase(unittest.TestCase):
self.your_room = RoomAlias.from_string("#your-room:test")
self.remote_room = RoomAlias.from_string("#another:remote")
- @defer.inlineCallbacks
+ return hs
+
def test_get_local_association(self):
- yield self.store.create_room_alias_association(
- self.my_room, "!8765qwer:test", ["test"]
+ self.get_success(
+ self.store.create_room_alias_association(
+ self.my_room, "!8765qwer:test", ["test"]
+ )
)
- result = yield self.handler.get_association(self.my_room)
+ result = self.get_success(self.handler.get_association(self.my_room))
self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
- @defer.inlineCallbacks
def test_get_remote_association(self):
self.mock_federation.make_query.return_value = defer.succeed(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
- result = yield self.handler.get_association(self.remote_room)
+ result = self.get_success(self.handler.get_association(self.remote_room))
self.assertEquals(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
@@ -93,19 +88,241 @@ class DirectoryTestCase(unittest.TestCase):
ignore_backoff=True,
)
- @defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_room_alias_association(
- self.your_room, "!8765asdf:test", ["test"]
+ self.get_success(
+ self.store.create_room_alias_association(
+ self.your_room, "!8765asdf:test", ["test"]
+ )
)
- response = yield self.query_handlers["directory"](
- {"room_alias": "#your-room:test"}
+ response = self.get_success(
+ self.handler.on_directory_query({"room_alias": "#your-room:test"})
)
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+class TestDeleteAlias(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_handlers().directory_handler
+ self.state_handler = hs.get_state_handler()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = RoomAlias.from_string(self.test_alias)
+
+ # Create a test user.
+ self.test_user = self.register_user("user", "pass", admin=False)
+ self.test_user_tok = self.login("user", "pass")
+ self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
+
+ def _create_alias(self, user):
+ # Create a new alias to this room.
+ self.get_success(
+ self.store.create_room_alias_association(
+ self.room_alias, self.room_id, ["test"], user
+ )
+ )
+
+ def test_delete_alias_not_allowed(self):
+ """A user that doesn't meet the expected guidelines cannot delete an alias."""
+ self._create_alias(self.admin_user)
+ self.get_failure(
+ self.handler.delete_association(
+ create_requester(self.test_user), self.room_alias
+ ),
+ synapse.api.errors.AuthError,
+ )
+
+ def test_delete_alias_creator(self):
+ """An alias creator can delete their own alias."""
+ # Create an alias from a different user.
+ self._create_alias(self.test_user)
+
+ # Delete the user's alias.
+ result = self.get_success(
+ self.handler.delete_association(
+ create_requester(self.test_user), self.room_alias
+ )
+ )
+ self.assertEquals(self.room_id, result)
+
+ # Confirm the alias is gone.
+ self.get_failure(
+ self.handler.get_association(self.room_alias),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_delete_alias_admin(self):
+ """A server admin can delete an alias created by another user."""
+ # Create an alias from a different user.
+ self._create_alias(self.test_user)
+
+ # Delete the user's alias as the admin.
+ result = self.get_success(
+ self.handler.delete_association(
+ create_requester(self.admin_user), self.room_alias
+ )
+ )
+ self.assertEquals(self.room_id, result)
+
+ # Confirm the alias is gone.
+ self.get_failure(
+ self.handler.get_association(self.room_alias),
+ synapse.api.errors.SynapseError,
+ )
+
+ def test_delete_alias_sufficient_power(self):
+ """A user with a sufficient power level should be able to delete an alias."""
+ self._create_alias(self.admin_user)
+
+ # Increase the user's power level.
+ self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ {"users": {self.test_user: 100}},
+ tok=self.admin_user_tok,
+ )
+
+ # They can now delete the alias.
+ result = self.get_success(
+ self.handler.delete_association(
+ create_requester(self.test_user), self.room_alias
+ )
+ )
+ self.assertEquals(self.room_id, result)
+
+ # Confirm the alias is gone.
+ self.get_failure(
+ self.handler.get_association(self.room_alias),
+ synapse.api.errors.SynapseError,
+ )
+
+
+class CanonicalAliasTestCase(unittest.HomeserverTestCase):
+ """Test modifications of the canonical alias when delete aliases.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_handlers().directory_handler
+ self.state_handler = hs.get_state_handler()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ # Create a test room
+ self.room_id = self.helper.create_room_as(
+ self.admin_user, tok=self.admin_user_tok
+ )
+
+ self.test_alias = "#test:test"
+ self.room_alias = self._add_alias(self.test_alias)
+
+ def _add_alias(self, alias: str) -> RoomAlias:
+ """Add an alias to the test room."""
+ room_alias = RoomAlias.from_string(alias)
+
+ # Create a new alias to this room.
+ self.get_success(
+ self.store.create_room_alias_association(
+ room_alias, self.room_id, ["test"], self.admin_user
+ )
+ )
+ return room_alias
+
+ def _set_canonical_alias(self, content):
+ """Configure the canonical alias state on the room."""
+ self.helper.send_state(
+ self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
+ )
+
+ def _get_canonical_alias(self):
+ """Get the canonical alias state of the room."""
+ return self.get_success(
+ self.state_handler.get_current_state(
+ self.room_id, EventTypes.CanonicalAlias, ""
+ )
+ )
+
+ def test_remove_alias(self):
+ """Removing an alias that is the canonical alias should remove it there too."""
+ # Set this new alias as the canonical alias for this room
+ self._set_canonical_alias(
+ {"alias": self.test_alias, "alt_aliases": [self.test_alias]}
+ )
+
+ data = self._get_canonical_alias()
+ self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+
+ # Finally, delete the alias.
+ self.get_success(
+ self.handler.delete_association(
+ create_requester(self.admin_user), self.room_alias
+ )
+ )
+
+ data = self._get_canonical_alias()
+ self.assertNotIn("alias", data["content"])
+ self.assertNotIn("alt_aliases", data["content"])
+
+ def test_remove_other_alias(self):
+ """Removing an alias listed as in alt_aliases should remove it there too."""
+ # Create a second alias.
+ other_test_alias = "#test2:test"
+ other_room_alias = self._add_alias(other_test_alias)
+
+ # Set the alias as the canonical alias for this room.
+ self._set_canonical_alias(
+ {
+ "alias": self.test_alias,
+ "alt_aliases": [self.test_alias, other_test_alias],
+ }
+ )
+
+ data = self._get_canonical_alias()
+ self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(
+ data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
+ )
+
+ # Delete the second alias.
+ self.get_success(
+ self.handler.delete_association(
+ create_requester(self.admin_user), other_room_alias
+ )
+ )
+
+ data = self._get_canonical_alias()
+ self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+
+
class TestCreateAliasACL(unittest.HomeserverTestCase):
user_id = "@test:test"
@@ -132,7 +349,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
- ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
self.render(request)
self.assertEquals(403, channel.code, channel.result)
@@ -143,7 +360,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
- ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -158,7 +375,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id)
request, channel = self.make_request(
- "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
+ "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -190,7 +407,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list disabled so we shouldn't be allowed to publish rooms
room_id = self.helper.create_room_as(self.user_id)
request, channel = self.make_request(
- "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}'
+ "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.render(request)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8dccc6826e..854eb6c024 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,9 +17,11 @@
import mock
+import signedjson.key as key
+import signedjson.sign as sign
+
from twisted.internet import defer
-import synapse.api.errors
import synapse.handlers.e2e_keys
import synapse.storage
from synapse.api import errors
@@ -145,3 +149,357 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
+
+ @defer.inlineCallbacks
+ def test_replace_master_key(self):
+ """uploading a new signing key should make the old signing key unavailable"""
+ local_user = "@boris:" + self.hs.hostname
+ keys1 = {
+ "master_key": {
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ },
+ }
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+
+ keys2 = {
+ "master_key": {
+ # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw": "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw"
+ },
+ }
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+
+ devices = yield self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user
+ )
+ self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
+
+ @defer.inlineCallbacks
+ def test_reupload_signatures(self):
+ """re-uploading a signature should not fail"""
+ local_user = "@boris:" + self.hs.hostname
+ keys1 = {
+ "master_key": {
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ": "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ"
+ },
+ },
+ "self_signing_key": {
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ "user_id": local_user,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ },
+ },
+ }
+ master_signing_key = key.decode_signing_key_base64(
+ "ed25519",
+ "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ",
+ "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8",
+ )
+ sign.sign_json(keys1["self_signing_key"], local_user, master_signing_key)
+ signing_key = key.decode_signing_key_base64(
+ "ed25519",
+ "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
+ )
+ yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+
+ # upload two device keys, which will be signed later by the self-signing key
+ device_key_1 = {
+ "user_id": local_user,
+ "device_id": "abc",
+ "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "keys": {
+ "ed25519:abc": "base64+ed25519+key",
+ "curve25519:abc": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+ }
+ device_key_2 = {
+ "user_id": local_user,
+ "device_id": "def",
+ "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "keys": {
+ "ed25519:def": "base64+ed25519+key",
+ "curve25519:def": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+ }
+
+ yield self.handler.upload_keys_for_user(
+ local_user, "abc", {"device_keys": device_key_1}
+ )
+ yield self.handler.upload_keys_for_user(
+ local_user, "def", {"device_keys": device_key_2}
+ )
+
+ # sign the first device key and upload it
+ del device_key_1["signatures"]
+ sign.sign_json(device_key_1, local_user, signing_key)
+ yield self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1}}
+ )
+
+ # sign the second device key and upload both device keys. The server
+ # should ignore the first device key since it already has a valid
+ # signature for it
+ del device_key_2["signatures"]
+ sign.sign_json(device_key_2, local_user, signing_key)
+ yield self.handler.upload_signatures_for_device_keys(
+ local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+ )
+
+ device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
+ device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
+ devices = yield self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user
+ )
+ del devices["device_keys"][local_user]["abc"]["unsigned"]
+ del devices["device_keys"][local_user]["def"]["unsigned"]
+ self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
+ self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
+
+ @defer.inlineCallbacks
+ def test_self_signing_key_doesnt_show_up_as_device(self):
+ """signing keys should be hidden when fetching a user's devices"""
+ local_user = "@boris:" + self.hs.hostname
+ keys1 = {
+ "master_key": {
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {
+ "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ },
+ }
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+
+ res = None
+ try:
+ yield self.hs.get_device_handler().check_device_registered(
+ user_id=local_user,
+ device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+ initial_device_display_name="new display name",
+ )
+ except errors.SynapseError as e:
+ res = e.code
+ self.assertEqual(res, 400)
+
+ res = yield self.handler.query_local_devices({local_user: None})
+ self.assertDictEqual(res, {local_user: {}})
+
+ @defer.inlineCallbacks
+ def test_upload_signatures(self):
+ """should check signatures that are uploaded"""
+ # set up a user with cross-signing keys and a device. This user will
+ # try uploading signatures
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
+ device_key = {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
+ "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
+ "signatures": {local_user: {"ed25519:xyz": "something"}},
+ }
+ device_signing_key = key.decode_signing_key_base64(
+ "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
+ )
+
+ yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"device_keys": device_key}
+ )
+
+ # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
+ master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
+ master_key = {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ }
+ master_signing_key = key.decode_signing_key_base64(
+ "ed25519", master_pubkey, "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0"
+ )
+ usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw"
+ usersigning_key = {
+ # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs
+ "user_id": local_user,
+ "usage": ["user_signing"],
+ "keys": {"ed25519:" + usersigning_pubkey: usersigning_pubkey},
+ }
+ usersigning_signing_key = key.decode_signing_key_base64(
+ "ed25519", usersigning_pubkey, "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs"
+ )
+ sign.sign_json(usersigning_key, local_user, master_signing_key)
+ # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
+ selfsigning_pubkey = "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ"
+ selfsigning_key = {
+ "user_id": local_user,
+ "usage": ["self_signing"],
+ "keys": {"ed25519:" + selfsigning_pubkey: selfsigning_pubkey},
+ }
+ selfsigning_signing_key = key.decode_signing_key_base64(
+ "ed25519", selfsigning_pubkey, "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8"
+ )
+ sign.sign_json(selfsigning_key, local_user, master_signing_key)
+ cross_signing_keys = {
+ "master_key": master_key,
+ "user_signing_key": usersigning_key,
+ "self_signing_key": selfsigning_key,
+ }
+ yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+
+ # set up another user with a master key. This user will be signed by
+ # the first user
+ other_user = "@otherboris:" + self.hs.hostname
+ other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
+ other_master_key = {
+ # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
+ }
+ yield self.handler.upload_signing_keys_for_user(
+ other_user, {"master_key": other_master_key}
+ )
+
+ # test various signature failures (see below)
+ ret = yield self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ device_id: {
+ "user_id": local_user,
+ "device_id": device_id,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha256",
+ "m.megolm.v1.aes-sha",
+ ],
+ "keys": {
+ "curve25519:xyz": "curve25519+key",
+ # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+ "ed25519:xyz": device_pubkey,
+ },
+ "signatures": {
+ local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ },
+ },
+ # fails because device is unknown
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": local_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+ },
+ },
+ # fails because the signature is invalid
+ # should fail with INVALID_SIGNATURE
+ master_pubkey: {
+ "user_id": local_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_pubkey: master_pubkey},
+ "signatures": {
+ local_user: {"ed25519:" + device_pubkey: "something"}
+ },
+ },
+ },
+ other_user: {
+ # fails because the device is not the user's master-signing key
+ # should fail with NOT_FOUND
+ "unknown": {
+ "user_id": other_user,
+ "device_id": "unknown",
+ "signatures": {
+ local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ },
+ },
+ other_master_pubkey: {
+ # fails because the key doesn't match what the server has
+ # should fail with UNKNOWN
+ "user_id": other_user,
+ "usage": ["master"],
+ "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
+ "something": "random",
+ "signatures": {
+ local_user: {"ed25519:" + usersigning_pubkey: "something"}
+ },
+ },
+ },
+ },
+ )
+
+ user_failures = ret["failures"][local_user]
+ self.assertEqual(
+ user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+ )
+ self.assertEqual(
+ user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
+ )
+ self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+
+ other_user_failures = ret["failures"][other_user]
+ self.assertEqual(
+ other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
+ )
+ self.assertEqual(
+ other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+ )
+
+ # test successful signatures
+ del device_key["signatures"]
+ sign.sign_json(device_key, local_user, selfsigning_signing_key)
+ sign.sign_json(master_key, local_user, device_signing_key)
+ sign.sign_json(other_master_key, local_user, usersigning_signing_key)
+ ret = yield self.handler.upload_signatures_for_device_keys(
+ local_user,
+ {
+ local_user: {device_id: device_key, master_pubkey: master_key},
+ other_user: {other_master_pubkey: other_master_key},
+ },
+ )
+
+ self.assertEqual(ret["failures"], {})
+
+ # fetch the signed keys/devices and make sure that the signatures are there
+ ret = yield self.handler.query_devices(
+ {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ )
+
+ self.assertEqual(
+ ret["device_keys"][local_user]["xyz"]["signatures"][local_user][
+ "ed25519:" + selfsigning_pubkey
+ ],
+ device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey],
+ )
+ self.assertEqual(
+ ret["master_keys"][local_user]["signatures"][local_user][
+ "ed25519:" + device_id
+ ],
+ master_key["signatures"][local_user]["ed25519:" + device_id],
+ )
+ self.assertEqual(
+ ret["master_keys"][other_user]["signatures"][local_user][
+ "ed25519:" + usersigning_pubkey
+ ],
+ other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
+ )
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 2e72a1dd23..70f172eb02 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ version_etag = res["etag"]
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
# check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1")
+ self.assertEqual(res["etag"], version_etag)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "1",
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
+ "count": 0,
},
)
@@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"version": "2",
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
+ "count": 0,
},
)
@@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"]
self.assertDictEqual(
res,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
+ "count": 0,
},
)
@@ -187,9 +198,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, 404)
@defer.inlineCallbacks
- def test_update_bad_version(self):
- """Check that we get a 400 if the version in the body is missing or
- doesn't match
+ def test_update_omitted_version(self):
+ """Check that the update succeeds if the version is missing from the body
"""
version = yield self.handler.create_version(
self.local_user,
@@ -197,19 +207,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
self.assertEqual(version, "1")
- res = None
- try:
- yield self.handler.update_version(
- self.local_user,
- version,
- {
- "algorithm": "m.megolm_backup.v1",
- "auth_data": "revised_first_version_auth_data",
- },
- )
- except errors.SynapseError as e:
- res = e.code
- self.assertEqual(res, 400)
+ yield self.handler.update_version(
+ self.local_user,
+ version,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ },
+ )
+
+ # check we can retrieve it as the current version
+ res = yield self.handler.get_version_info(self.local_user)
+ del res["etag"] # etag is opaque, so don't test its contents
+ self.assertDictEqual(
+ res,
+ {
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "revised_first_version_auth_data",
+ "version": version,
+ "count": 0,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def test_update_bad_version(self):
+ """Check that we get a 400 if the version in the body doesn't match
+ """
+ version = yield self.handler.create_version(
+ self.local_user,
+ {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+ )
+ self.assertEqual(version, "1")
res = None
try:
@@ -394,40 +422,58 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+ # get the etag to compare to future versions
+ res = yield self.handler.get_version_info(self.local_user)
+ backup_etag = res["etag"]
+ self.assertEqual(res["count"], 1)
+
new_room_keys = copy.deepcopy(room_keys)
- new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33']
+ new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]
# test that increasing the message_index doesn't replace the existing session
- new_room_key['first_message_index'] = 2
- new_room_key['session_data'] = 'new'
+ new_room_key["first_message_index"] = 2
+ new_room_key["session_data"] = "new"
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
res = yield self.handler.get_room_keys(self.local_user, version)
self.assertEqual(
- res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'],
+ res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# test that marking the session as verified however /does/ replace it
- new_room_key['is_verified'] = True
+ new_room_key["is_verified"] = True
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
res = yield self.handler.get_room_keys(self.local_user, version)
self.assertEqual(
- res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
+ res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should NOT be equal now, since the key changed
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertNotEqual(res["etag"], backup_etag)
+ backup_etag = res["etag"]
+
# test that a session with a higher forwarded_count doesn't replace one
# with a lower forwarding count
- new_room_key['forwarded_count'] = 2
- new_room_key['session_data'] = 'other'
+ new_room_key["forwarded_count"] = 2
+ new_room_key["session_data"] = "other"
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
res = yield self.handler.get_room_keys(self.local_user, version)
self.assertEqual(
- res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new"
+ res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
+ # the etag should be the same since the session did not change
+ res = yield self.handler.get_version_info(self.local_user)
+ self.assertEqual(res["etag"], backup_etag)
+
# TODO: check edge cases as well as the common variations here
@defer.inlineCallbacks
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index b1ae15a2bd..132e35651d 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
+logger = logging.getLogger(__name__)
+
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -50,7 +56,6 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
d = self.handler.on_exchange_third_party_invite_request(
- origin="example.com",
room_id=room_id,
event_dict={
"type": EventTypes.Member,
@@ -66,10 +71,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"token": invite_token,
"signatures": {
"magic.forest": {
- "ed25519:3": (
- "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIs"
- "NffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
- )
+ "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
}
},
},
@@ -83,3 +85,125 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.code, 403, failure)
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
+
+ def test_rejected_message_event_state(self):
+ """
+ Check that we store the state group correctly for rejected non-state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def test_rejected_state_event_state(self):
+ """
+ Check that we store the state group correctly for rejected state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": "org.matrix.test",
+ "state_key": "test_key",
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def _build_and_send_join_event(self, other_server, other_user, room_id):
+ join_event = self.get_success(
+ self.handler.on_make_join_request(other_server, room_id, other_user)
+ )
+ # the auth code requires that a signature exists, but doesn't check that
+ # signature... go figure.
+ join_event.signatures[other_server] = {"x": "y"}
+ with LoggingContext(request="send_join"):
+ d = run_in_background(
+ self.handler.on_send_join_request, other_server, join_event
+ )
+ self.get_success(d)
+
+ # sanity-check: the room should show that the new user is a member
+ r = self.get_success(self.store.get_current_state_ids(room_id))
+ self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+ return join_event
diff --git a/tests/handlers/test_identity.py b/tests/handlers/test_identity.py
index 99ce94db52..0ab0356109 100644
--- a/tests/handlers/test_identity.py
+++ b/tests/handlers/test_identity.py
@@ -35,27 +35,36 @@ class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.address = "test@test"
self.is_server_name = "testis"
- self.rewritten_is_url = "int.testis"
+ self.is_server_url = "https://testis"
+ self.rewritten_is_url = "https://int.testis"
config = self.default_config()
- config["trusted_third_party_id_servers"] = [
- self.is_server_name,
- ]
+ config["trusted_third_party_id_servers"] = [self.is_server_name]
config["rewrite_identity_server_urls"] = {
- self.is_server_name: self.rewritten_is_url,
+ self.is_server_url: self.rewritten_is_url
}
- mock_http_client = Mock(spec=[
- "post_urlencoded_get_json",
- ])
- mock_http_client.post_urlencoded_get_json.return_value = defer.succeed({
- "address": self.address,
- "medium": "email",
- })
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.side_effect = defer.succeed({})
+ mock_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
self.hs = self.setup_test_homeserver(
- config=config,
- simple_http_client=mock_http_client,
+ config=config, simple_http_client=mock_http_client
+ )
+
+ mock_blacklisting_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_blacklisting_http_client.get_json.side_effect = defer.succeed({})
+ mock_blacklisting_http_client.post_json_get_json.return_value = defer.succeed(
+ {"address": self.address, "medium": "email"}
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_blacklisting_http_client
)
return self.hs
@@ -71,38 +80,37 @@ class ThreepidISRewrittenURLTestCase(unittest.HomeserverTestCase):
* the original, non-rewritten, server name is stored in the database
"""
handler = self.hs.get_handlers().identity_handler
- post_urlenc_get_json = self.hs.get_simple_http_client().post_urlencoded_get_json
+ post_json_get_json = handler.blacklisting_http_client.post_json_get_json
store = self.hs.get_datastore()
- creds = {
- "sid": "123",
- "client_secret": "some_secret",
- }
+ creds = {"sid": "123", "client_secret": "some_secret"}
# Make sure processing the mocked response goes through.
- data = self.get_success(handler.bind_threepid(
- {
- "id_server": self.is_server_name,
- "client_secret": creds["client_secret"],
- "sid": creds["sid"],
- },
- self.user_id,
- ))
+ data = self.get_success(
+ handler.bind_threepid(
+ client_secret=creds["client_secret"],
+ sid=creds["sid"],
+ mxid=self.user_id,
+ id_server=self.is_server_name,
+ use_v2=False,
+ )
+ )
self.assertEqual(data.get("address"), self.address)
# Check that the request was done against the rewritten server name.
- post_urlenc_get_json.assert_called_once_with(
- "https://%s/_matrix/identity/api/v1/3pid/bind" % self.rewritten_is_url,
+ post_json_get_json.assert_called_once_with(
+ "%s/_matrix/identity/api/v1/3pid/bind" % (self.rewritten_is_url,),
{
- 'sid': creds['sid'],
- 'client_secret': creds["client_secret"],
- 'mxid': self.user_id,
- }
+ "sid": creds["sid"],
+ "client_secret": creds["client_secret"],
+ "mxid": self.user_id,
+ },
+ headers={},
)
# Check that the original server name is saved in the database instead of the
# rewritten one.
- id_servers = self.get_success(store.get_id_servers_user_bound(
- self.user_id, "email", self.address
- ))
+ id_servers = self.get_success(
+ store.get_id_servers_user_bound(self.user_id, "email", self.address)
+ )
self.assertEqual(id_servers, [self.is_server_name])
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index f70c6e7d65..05ea40a7de 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,9 +19,10 @@ from mock import Mock, call
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
-from synapse.events import room_version_to_event_format
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
+ EXTERNAL_PROCESS_EXPIRY,
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
IDLE_TIMER,
@@ -337,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
new_state = handle_timeout(
- state, is_mine=True, syncing_user_ids=set([user_id]), now=now
+ state, is_mine=True, syncing_user_ids={user_id}, now=now
)
self.assertIsNotNone(new_state)
@@ -413,6 +414,44 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEquals(state, new_state)
+class PresenceHandlerTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.presence_handler = hs.get_presence_handler()
+ self.clock = hs.get_clock()
+
+ def test_external_process_timeout(self):
+ """Test that if an external process doesn't update the records for a while
+ we time out their syncing users presence.
+ """
+ process_id = 1
+ user_id = "@test:server"
+
+ # Notify handler that a user is now syncing.
+ self.get_success(
+ self.presence_handler.update_external_syncs_row(
+ process_id, user_id, True, self.clock.time_msec()
+ )
+ )
+
+ # Check that if we wait a while without telling the handler the user has
+ # stopped syncing that their presence state doesn't get timed out.
+ self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2)
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
+ # Check that if the external process timeout fires, then the syncing
+ # user gets timed out
+ self.reactor.advance(EXTERNAL_PROCESS_EXPIRY)
+
+ state = self.get_success(
+ self.presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
+
class PresenceJoinTestCase(unittest.HomeserverTestCase):
"""Tests remote servers get told about presence of users in the room when
they join and when new local users join.
@@ -455,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, "@test2:server")
# Mark test2 as online, test will be offline with a last_active of 0
- self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ )
)
self.reactor.pump([0]) # Wait for presence updates to be handled
@@ -504,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
room_id = self.helper.create_room_as(self.user_id)
# Mark test as online
- self.presence_handler.set_state(
- UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ )
)
# Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet
- self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ self.get_success(
+ self.presence_handler.set_state(
+ UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ )
)
# Add servers to the room
@@ -540,7 +585,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
- destinations=set(("server2", "server3")), states=[expected_state]
+ destinations={"server2", "server3"}, states=[expected_state]
)
def _add_new_user(self, room_id, user_id):
@@ -549,7 +594,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
hostname = get_domain_from_id(user_id)
- room_version = self.get_success(self.store.get_room_version(room_id))
+ room_version = self.get_success(self.store.get_room_version_id(room_id))
builder = EventBuilder(
state=self.state,
@@ -558,7 +603,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
clock=self.clock,
hostname=hostname,
signing_key=self.random_signing_key,
- format_version=room_version_to_event_format(room_version),
+ room_version=KNOWN_ROOM_VERSIONS[room_version],
room_id=room_id,
type=EventTypes.Member,
sender=user_id,
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 45cbfeb9a4..2311040201 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -71,9 +71,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(
- self.frank.localpart, "Frank", 1,
- )
+ yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
displayname = yield self.handler.get_displayname(self.frank)
@@ -127,7 +125,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png", 1,
+ self.frank.localpart, "http://my.server/me.png", 1
)
avatar_url = yield self.handler.get_avatar_url(self.frank)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b2aa5c2478..5e7f14a3d5 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.constants import UserTypes
-from synapse.api.errors import ResourceLimitError, SynapseError
+from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.rest.client.v2_alpha.register import _map_email_to_displayname
from synapse.types import RoomAlias, UserID, create_requester
@@ -45,7 +45,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
hs_config["max_mau_value"] = 50
hs_config["limit_usage_by_mau"] = True
- hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True)
+ hs = self.setup_test_homeserver(config=hs_config)
return hs
def prepare(self, reactor, clock, hs):
@@ -53,7 +53,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock()
self.macaroon_generator = Mock(
- generate_access_token=Mock(return_value='secret')
+ generate_access_token=Mock(return_value="secret")
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.handler = self.hs.get_registration_handler()
@@ -68,27 +68,23 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = frank.to_string()
requester = create_requester(user_id)
result_user_id, result_token = self.get_success(
- self.handler.get_or_create_user(requester, frank.localpart, "Frankie")
+ self.get_or_create_user(requester, frank.localpart, "Frankie")
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
- self.assertEquals(result_token, 'secret')
+ self.assertEquals(result_token, "secret")
def test_if_user_exists(self):
store = self.hs.get_datastore()
frank = UserID.from_string("@frank:test")
self.get_success(
- store.register(
- user_id=frank.to_string(),
- token="jkv;g498752-43gj['eamb!-5",
- password_hash=None,
- )
+ store.register_user(user_id=frank.to_string(), password_hash=None)
)
local_part = frank.localpart
user_id = frank.to_string()
requester = create_requester(user_id)
result_user_id, result_token = self.get_success(
- self.handler.get_or_create_user(requester, local_part, None)
+ self.get_or_create_user(requester, local_part, None)
)
self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None)
@@ -96,9 +92,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- self.get_success(
- self.handler.get_or_create_user(self.requester, 'a', "display_name")
- )
+ self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True
@@ -106,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User"))
+ self.get_success(self.get_or_create_user(self.requester, "c", "User"))
def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True
@@ -114,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.lots_of_users)
)
self.get_failure(
- self.handler.get_or_create_user(self.requester, 'b', "display_name"),
+ self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
@@ -122,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
self.get_failure(
- self.handler.get_or_create_user(self.requester, 'b', "display_name"),
+ self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
@@ -132,64 +126,89 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.lots_of_users)
)
self.get_failure(
- self.handler.register(localpart="local_part"), ResourceLimitError
+ self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
self.get_failure(
- self.handler.register(localpart="local_part"), ResourceLimitError
+ self.handler.register_user(localpart="local_part"), ResourceLimitError
)
def test_auto_create_auto_join_rooms(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- res = self.get_success(self.handler.register(localpart='jeff'))
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
- self.assertTrue(room_id['room_id'] in rooms)
+ self.assertTrue(room_id["room_id"] in rooms)
self.assertEqual(len(rooms), 1)
def test_auto_create_auto_join_rooms_with_no_rooms(self):
self.hs.config.auto_join_rooms = []
frank = UserID.from_string("@frank:test")
- res = self.get_success(self.handler.register(frank.localpart))
- self.assertEqual(res[0], frank.to_string())
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(frank.localpart))
+ self.assertEqual(user_id, frank.to_string())
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_auto_create_auto_join_where_room_is_another_domain(self):
self.hs.config.auto_join_rooms = ["#room:another"]
frank = UserID.from_string("@frank:test")
- res = self.get_success(self.handler.register(frank.localpart))
- self.assertEqual(res[0], frank.to_string())
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(frank.localpart))
+ self.assertEqual(user_id, frank.to_string())
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_auto_create_auto_join_where_auto_create_is_false(self):
self.hs.config.autocreate_auto_join_rooms = False
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- res = self.get_success(self.handler.register(localpart='jeff'))
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
- def test_auto_create_auto_join_rooms_when_support_user_exists(self):
+ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_support_user = Mock(return_value=True)
- res = self.get_success(self.handler.register(localpart='support'))
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ self.store.is_real_user = Mock(return_value=False)
+ user_id = self.get_success(self.handler.register_user(localpart="support"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
directory_handler = self.hs.get_handlers().directory_handler
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
+ def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+
+ self.store.count_real_users = Mock(return_value=1)
+ self.store.is_real_user = Mock(return_value=True)
+ user_id = self.get_success(self.handler.register_user(localpart="real"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = self.get_success(directory_handler.get_association(room_alias))
+
+ self.assertTrue(room_id["room_id"] in rooms)
+ self.assertEqual(len(rooms), 1)
+
+ def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+
+ self.store.count_real_users = Mock(return_value=2)
+ self.store.is_real_user = Mock(return_value=True)
+ user_id = self.get_success(self.handler.register_user(localpart="real"))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 0)
+
def test_auto_create_auto_join_where_no_consent(self):
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
@@ -212,54 +231,100 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# When:-
# * the user is registered and post consent actions are called
- res = self.get_success(self.handler.register(localpart='jeff'))
- self.get_success(self.handler.post_consent_actions(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+ self.get_success(self.handler.post_consent_actions(user_id))
# Then:-
# * Ensure that they have not been joined to the room
- rooms = self.get_success(self.store.get_rooms_for_user(res[0]))
+ rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
def test_register_support_user(self):
- res = self.get_success(
- self.handler.register(localpart='user', user_type=UserTypes.SUPPORT)
+ user_id = self.get_success(
+ self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
)
- self.assertTrue(self.store.is_support_user(res[0]))
+ d = self.store.is_support_user(user_id)
+ self.assertTrue(self.get_success(d))
def test_register_not_support_user(self):
- res = self.get_success(self.handler.register(localpart='user'))
- self.assertFalse(self.store.is_support_user(res[0]))
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+ d = self.store.is_support_user(user_id)
+ self.assertFalse(self.get_success(d))
def test_invalid_user_id_length(self):
invalid_user_id = "x" * 256
self.get_failure(
- self.handler.register(localpart=invalid_user_id),
- SynapseError
+ self.handler.register_user(localpart=invalid_user_id), SynapseError
)
def test_email_to_displayname_mapping(self):
"""Test that custom emails are mapped to new user displaynames correctly"""
self._check_mapping(
- "jack-phillips.rivers@big-org.com",
- "Jack-Phillips Rivers [Big-Org]",
+ "jack-phillips.rivers@big-org.com", "Jack-Phillips Rivers [Big-Org]"
)
- self._check_mapping(
- "bob.jones@matrix.org",
- "Bob Jones [Tchap Admin]",
- )
+ self._check_mapping("bob.jones@matrix.org", "Bob Jones [Tchap Admin]")
- self._check_mapping(
- "bob-jones.blabla@gouv.fr",
- "Bob-Jones Blabla [Gouv]",
- )
+ self._check_mapping("bob-jones.blabla@gouv.fr", "Bob-Jones Blabla [Gouv]")
# Multibyte unicode characters
self._check_mapping(
- u"j\u030a\u0065an-poppy.seed@example.com",
- u"J\u030a\u0065an-Poppy Seed [Example]",
+ "j\u030a\u0065an-poppy.seed@example.com",
+ "J\u030a\u0065an-Poppy Seed [Example]",
)
def _check_mapping(self, i, expected):
result = _map_email_to_displayname(i)
self.assertEqual(result, expected)
+
+ @defer.inlineCallbacks
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ """Creates a new user if the user does not exist,
+ else revokes all previous access tokens and generates a new one.
+
+ XXX: this used to be in the main codebase, but was only used by this file,
+ so got moved here. TODO: get rid of it, probably
+
+ Args:
+ localpart : The local part of the user ID to register. If None,
+ one will be randomly generated.
+ Returns:
+ A tuple of (user_id, access_token).
+ """
+ if localpart is None:
+ raise SynapseError(400, "Request must include user id")
+ yield self.hs.get_auth().check_auth_blocking()
+ need_register = True
+
+ try:
+ yield self.handler.check_username(localpart)
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ need_register = False
+ else:
+ raise
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ token = self.macaroon_generator.generate_access_token(user_id)
+
+ if need_register:
+ yield self.handler.register_with_store(
+ user_id=user_id,
+ password_hash=password_hash,
+ create_profile_with_displayname=user.localpart,
+ )
+ else:
+ yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+
+ yield self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+
+ if displayname is not None:
+ # logger.info("setting user display name: %s -> %s", user_id, displayname)
+ yield self.hs.get_profile_handler().set_displayname(
+ user, requester, displayname, by_admin=True
+ )
+
+ return user_id, token
diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py
deleted file mode 100644
index 61eebb6985..0000000000
--- a/tests/handlers/test_roomlist.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.handlers.room_list import RoomListNextBatch
-
-import tests.unittest
-import tests.utils
-
-
-class RoomListTestCase(tests.unittest.TestCase):
- """ Tests RoomList's RoomListNextBatch. """
-
- def setUp(self):
- pass
-
- def test_check_read_batch_tokens(self):
- batch_token = RoomListNextBatch(
- stream_ordering="abcdef",
- public_room_stream_id="123",
- current_limit=20,
- direction_is_forward=True,
- ).to_token()
- next_batch = RoomListNextBatch.from_token(batch_token)
- self.assertEquals(next_batch.stream_ordering, "abcdef")
- self.assertEquals(next_batch.public_room_stream_id, "123")
- self.assertEquals(next_batch.current_limit, 20)
- self.assertEquals(next_batch.direction_is_forward, True)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 2710c991cf..8e6b0b7536 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -13,16 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.storage.data_stores.main import stats
from tests import unittest
+# The expected number of state events in a fresh public room.
+EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM = 5
+
+# The expected number of state events in a fresh private room.
+#
+# Note: we increase this by 1 on the dinsic branch as we send
+# a "im.vector.room.access_rules" state event into new private rooms
+EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 7
+
class StatsRoomTests(unittest.HomeserverTestCase):
@@ -33,7 +38,6 @@ class StatsRoomTests(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
-
self.store = hs.get_datastore()
self.handler = self.hs.get_stats_handler()
@@ -42,40 +46,84 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
- {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ {"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_rooms",
"progress_json": "{}",
- "depends_on": "populate_stats_createtables",
+ "depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_cleanup",
+ "update_name": "populate_stats_process_users",
"progress_json": "{}",
"depends_on": "populate_stats_process_rooms",
},
)
)
+ self.get_success(
+ self.store.db.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_users",
+ },
+ )
+ )
+
+ def get_all_room_state(self):
+ return self.store.db.simple_select_list(
+ "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
+ )
+
+ def _get_current_stats(self, stats_type, stat_id):
+ table, id_col = stats.TYPE_TO_TABLE[stats_type]
+
+ cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
+ stats.PER_SLICE_FIELDS[stats_type]
+ )
+
+ end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
+
+ return self.get_success(
+ self.store.db.simple_select_one(
+ table + "_historical",
+ {id_col: stat_id, end_ts: end_ts},
+ cols,
+ allow_none=True,
+ )
+ )
+
+ def _perform_background_initial_update(self):
+ # Do the initial population of the stats via the background update
+ self._add_background_updates()
+
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
def test_initial_room(self):
"""
The background updates will build the table from scratch.
"""
- r = self.get_success(self.store.get_all_room_state())
+ r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 0)
# Disable stats
@@ -91,7 +139,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
# Stats disabled, shouldn't have done anything
- r = self.get_success(self.store.get_all_room_state())
+ r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 0)
# Enable stats
@@ -101,10 +149,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
- r = self.get_success(self.store.get_all_room_state())
+ r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
@@ -114,6 +166,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Ingestion via notify_new_event will ignore tokens that the background
update have already processed.
"""
+
self.reactor.advance(86401)
self.hs.config.stats_enabled = False
@@ -137,32 +190,44 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store._all_done = False
- self.get_success(self.store.update_stats_stream_pos(None))
+ self.store.db.updates._all_done = False
+ self.get_success(
+ self.store.db.simple_update_one(
+ table="stats_incremental_position",
+ keyvalues={},
+ updatevalues={"stream_id": 0},
+ )
+ )
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
- {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ {"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
# Now, before the table is actually ingested, add some more events.
self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
self.helper.join(room=room_1, user=u2, tok=u2_token)
+ # orig_delta_processor = self.store.
+
# Now do the initial ingestion.
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -172,9 +237,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- self.store._all_done = False
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ self.store.db.updates._all_done = False
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
self.reactor.advance(86401)
@@ -185,8 +254,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
self.helper.join(room=room_1, user=u3, tok=u3_token)
- # Get the deltas! There should be two -- day 1, and day 2.
- r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+ # self.handler.notify_new_event()
+
+ # We need to let the delta processor advance…
+ self.pump(10 * 60)
+
+ # Get the slices! There should be two -- day 1, and day 2.
+ r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
+
+ self.assertEqual(len(r), 2)
# The oldest has 2 joined members
self.assertEqual(r[-1]["joined_members"], 2)
@@ -194,114 +270,482 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# The newest has 3
self.assertEqual(r[0]["joined_members"], 3)
- def test_incorrect_state_transition(self):
- """
- If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
- (JOIN, INVITE, LEAVE, BAN), an error is raised.
- """
- events = {
- "a1": {"membership": Membership.LEAVE},
- "a2": {"membership": "not a real thing"},
- }
-
- def get_event(event_id, allow_none=True):
- m = Mock()
- m.content = events[event_id]
- d = defer.Deferred()
- self.reactor.callLater(0.0, d.callback, m)
- return d
-
- def get_received_ts(event_id):
- return defer.succeed(1)
-
- self.store.get_received_ts = get_received_ts
- self.store.get_event = get_event
-
- deltas = [
- {
- "type": EventTypes.Member,
- "state_key": "some_user",
- "room_id": "room",
- "event_id": "a1",
- "prev_event_id": "a2",
- "stream_id": 60,
- }
- ]
-
- f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ def test_create_user(self):
+ """
+ When we create a user, it should have statistics already ready.
+ """
+
+ u1 = self.register_user("u1", "pass")
+
+ u1stats = self._get_current_stats("user", u1)
+
+ self.assertIsNotNone(u1stats)
+
+ # not in any rooms by default
+ self.assertEqual(u1stats["joined_rooms"], 0)
+
+ def test_create_room(self):
+ """
+ When we create a room, it should have statistics already ready.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+ r1stats = self._get_current_stats("room", r1)
+ r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
+ r2stats = self._get_current_stats("room", r2)
+
+ self.assertIsNotNone(r1stats)
+ self.assertIsNotNone(r2stats)
+
+ # contains the default things you'd expect in a fresh room
self.assertEqual(
- f.value.args[0], "'not a real thing' is not a valid prev_membership"
- )
-
- # And the other way...
- deltas = [
- {
- "type": EventTypes.Member,
- "state_key": "some_user",
- "room_id": "room",
- "event_id": "a2",
- "prev_event_id": "a1",
- "stream_id": 100,
- }
- ]
-
- f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ r1stats["total_events"],
+ EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM,
+ "Wrong number of total_events in new room's stats!"
+ " You may need to update this if more state events are added to"
+ " the room creation process.",
+ )
self.assertEqual(
- f.value.args[0], "'not a real thing' is not a valid membership"
+ r2stats["total_events"],
+ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM,
+ "Wrong number of total_events in new room's stats!"
+ " You may need to update this if more state events are added to"
+ " the room creation process.",
)
- def test_redacted_prev_event(self):
+ self.assertEqual(
+ r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
+ )
+ self.assertEqual(
+ r2stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM
+ )
+
+ self.assertEqual(r1stats["joined_members"], 1)
+ self.assertEqual(r1stats["invited_members"], 0)
+ self.assertEqual(r1stats["banned_members"], 0)
+
+ self.assertEqual(r2stats["joined_members"], 1)
+ self.assertEqual(r2stats["invited_members"], 0)
+ self.assertEqual(r2stats["banned_members"], 0)
+
+ def test_send_message_increments_total_events(self):
"""
- If the prev_event does not exist, then it is assumed to be a LEAVE.
+ When we send a message, it increments total_events.
"""
+
+ self._perform_background_initial_update()
+
u1 = self.register_user("u1", "pass")
- u1_token = self.login("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+ r1stats_ante = self._get_current_stats("room", r1)
- room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send(r1, "hiss", tok=u1token)
- # Do the initial population of the user directory via the background update
- self._add_background_updates()
+ r1stats_post = self._get_current_stats("room", r1)
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
-
- events = {
- "a1": None,
- "a2": {"membership": Membership.JOIN},
- }
-
- def get_event(event_id, allow_none=True):
- if events.get(event_id):
- m = Mock()
- m.content = events[event_id]
- else:
- m = None
- d = defer.Deferred()
- self.reactor.callLater(0.0, d.callback, m)
- return d
-
- def get_received_ts(event_id):
- return defer.succeed(1)
-
- self.store.get_received_ts = get_received_ts
- self.store.get_event = get_event
-
- deltas = [
- {
- "type": EventTypes.Member,
- "state_key": "some_user:test",
- "room_id": room_1,
- "event_id": "a2",
- "prev_event_id": "a1",
- "stream_id": 100,
- }
- ]
-
- # Handle our fake deltas, which has a user going from LEAVE -> JOIN.
- self.get_success(self.handler._handle_deltas(deltas))
-
- # One delta, with two joined members -- the room creator, and our fake
- # user.
- r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
- self.assertEqual(len(r), 1)
- self.assertEqual(r[0]["joined_members"], 2)
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+
+ def test_send_state_event_nonoverwriting(self):
+ """
+ When we send a non-overwriting state event, it increments total_events AND current_state_events
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby"
+ )
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
+ )
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 1,
+ )
+
+ def test_send_state_event_overwriting(self):
+ """
+ When we send an overwriting state event, it increments total_events ONLY
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": True}, tok=u1token, state_key="tabby"
+ )
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.send_state(
+ r1, "cat.hissing", {"value": False}, tok=u1token, state_key="tabby"
+ )
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+
+ def test_join_first_time(self):
+ """
+ When a user joins a room for the first time, total_events, current_state_events and
+ joined_members should increase by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 1,
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
+ )
+
+ def test_join_after_leave(self):
+ """
+ When a user joins a room after being previously left, total_events and
+ joined_members should increase by exactly 1.
+ current_state_events should not increase.
+ left_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.join(r1, u2, tok=u2token)
+ self.helper.leave(r1, u2, tok=u2token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["left_members"] - r1stats_ante["left_members"], -1
+ )
+
+ def test_invited(self):
+ """
+ When a user invites another user, current_state_events, total_events and
+ invited_members should increase by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.invite(r1, u1, u2, tok=u1token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 1,
+ )
+ self.assertEqual(
+ r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
+ )
+
+ def test_join_after_invite(self):
+ """
+ When a user joins a room after being invited, total_events and
+ joined_members should increase by exactly 1.
+ current_state_events should not increase.
+ invited_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.invite(r1, u1, u2, tok=u1token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
+ )
+
+ def test_left(self):
+ """
+ When a user leaves a room after joining, total_events and
+ left_members should increase by exactly 1.
+ current_state_events should not increase.
+ joined_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.leave(r1, u2, tok=u2token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["left_members"] - r1stats_ante["left_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
+ )
+
+ def test_banned(self):
+ """
+ When a user is banned from a room after joining, total_events and
+ left_members should increase by exactly 1.
+ current_state_events should not increase.
+ banned_members should decrease by exactly 1.
+ """
+
+ self._perform_background_initial_update()
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+
+ self.helper.join(r1, u2, tok=u2token)
+
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
+
+ r1stats_post = self._get_current_stats("room", r1)
+
+ self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ self.assertEqual(
+ r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
+ 0,
+ )
+ self.assertEqual(
+ r1stats_post["banned_members"] - r1stats_ante["banned_members"], +1
+ )
+ self.assertEqual(
+ r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
+ )
+
+ def test_initial_background_update(self):
+ """
+ Test that statistics can be generated by the initial background update
+ handler.
+
+ This test also tests that stats rows are not created for new subjects
+ when stats are disabled. However, it may be desirable to change this
+ behaviour eventually to still keep current rows.
+ """
+
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ # test that these subjects, which were created during a time of disabled
+ # stats, do not have stats.
+ self.assertIsNone(self._get_current_stats("room", r1))
+ self.assertIsNone(self._get_current_stats("user", u1))
+
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+
+ self._perform_background_initial_update()
+
+ r1stats = self._get_current_stats("room", r1)
+ u1stats = self._get_current_stats("user", u1)
+
+ self.assertEqual(r1stats["joined_members"], 1)
+ self.assertEqual(
+ r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
+ )
+
+ self.assertEqual(u1stats["joined_rooms"], 1)
+
+ def test_incomplete_stats(self):
+ """
+ This tests that we track incomplete statistics.
+
+ We first test that incomplete stats are incrementally generated,
+ following the preparation of a background regen.
+
+ We then test that these incomplete rows are completed by the background
+ regen.
+ """
+
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ u2 = self.register_user("u2", "pass")
+ u2token = self.login("u2", "pass")
+ u3 = self.register_user("u3", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
+
+ # preparation stage of the initial background update
+ # Ugh, have to reset this flag
+ self.store.db.updates._all_done = False
+
+ self.get_success(
+ self.store.db.simple_delete(
+ "room_stats_current", {"1": 1}, "test_delete_stats"
+ )
+ )
+ self.get_success(
+ self.store.db.simple_delete(
+ "user_stats_current", {"1": 1}, "test_delete_stats"
+ )
+ )
+
+ self.helper.invite(r1, u1, u2, tok=u1token)
+ self.helper.join(r1, u2, tok=u2token)
+ self.helper.invite(r1, u1, u3, tok=u1token)
+ self.helper.send(r1, "thou shalt yield", tok=u1token)
+
+ # now do the background updates
+
+ self.store.db.updates._all_done = False
+ self.get_success(
+ self.store.db.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_prepare",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_users",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+ self.get_success(
+ self.store.db.simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_users",
+ },
+ )
+ )
+
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
+
+ r1stats_complete = self._get_current_stats("room", r1)
+ u1stats_complete = self._get_current_stats("user", u1)
+ u2stats_complete = self._get_current_stats("user", u2)
+
+ # now we make our assertions
+
+ # check that _complete rows are complete and correct
+ self.assertEqual(r1stats_complete["joined_members"], 2)
+ self.assertEqual(r1stats_complete["invited_members"], 1)
+
+ self.assertEqual(
+ r1stats_complete["current_state_events"],
+ 2 + EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM,
+ )
+
+ self.assertEqual(u1stats_complete["joined_rooms"], 1)
+ self.assertEqual(u2stats_complete["joined_rooms"], 1)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..4cbe9784ed 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
import tests.unittest
import tests.utils
-from tests.utils import setup_test_homeserver
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.sync_handler = SyncHandler(self.hs)
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+ self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
- user_id1 = "@user1:server"
- user_id2 = "@user2:server"
+ user_id1 = "@user1:test"
+ user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
+ self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
- yield self.store.upsert_monthly_active_user(user_id1)
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
self.hs.config.hs_disabled = True
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index cb8b4d2913..51e2b37218 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
+from tests.unittest import override_config
from tests.utils import register_federation_servlets
# Some local users to test with
@@ -47,7 +48,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
def _make_edu_transaction_json(edu_type, content):
- return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
+ return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
@@ -63,34 +64,36 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ datastores = Mock()
+ datastores.main = Mock(
+ spec=[
+ # Bits that Federation needs
+ "prep_send_transaction",
+ "delivered_txn",
+ "get_received_txn_response",
+ "set_received_txn_response",
+ "get_destination_retry_timings",
+ "get_devices_by_remote",
+ "maybe_store_room_on_invite",
+ # Bits that user_directory needs
+ "get_user_directory_stream_pos",
+ "get_current_state_deltas",
+ "get_device_updates_by_remote",
+ ]
+ )
+
+ # the tests assume that we are starting at unix time 1000
+ reactor.pump((1000,))
+
hs = self.setup_test_homeserver(
- datastore=(
- Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_devices_by_remote",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- ]
- )
- ),
- notifier=Mock(),
- http_client=mock_federation_client,
- keyring=mock_keyring,
+ notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
)
+ hs.datastores = datastores
+
return hs
def prepare(self, reactor, clock, hs):
- # the tests assume that we are starting at unix time 1000
- reactor.pump((1000,))
-
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -99,12 +102,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.event_source = hs.get_event_sources().sources["typing"]
self.datastore = hs.get_datastore()
- retry_timings_res = {"destination": "", "retry_last_ts": 0, "retry_interval": 0}
+ retry_timings_res = {
+ "destination": "",
+ "retry_last_ts": 0,
+ "retry_interval": 0,
+ "failure_ts": None,
+ }
self.datastore.get_destination_retry_timings.return_value = defer.succeed(
retry_timings_res
)
- self.datastore.get_devices_by_remote.return_value = (0, [])
+ self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+ (0, [])
+ )
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -113,19 +123,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_joined_room(room_id, user_id):
+ def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
- hs.get_auth().check_joined_room = check_joined_room
+ hs.get_auth().check_user_in_room = check_user_in_room
def get_joined_hosts_for_room(room_id):
- return set(member.domain for member in self.room_members)
+ return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_users_in_room(room_id):
- return set(str(u) for u in self.room_members)
+ return {str(u) for u in self.room_members}
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
@@ -134,11 +144,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
defer.succeed(1)
)
- self.datastore.get_current_state_deltas.return_value = None
+ self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
+ self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+ ([], 0)
+ )
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ None
+ )
def test_started_typing_local(self):
self.room_members = [U_APPLE, U_BANANA]
@@ -151,10 +166,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -166,6 +183,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_started_typing_remote_send(self):
self.room_members = [U_APPLE, U_ONION]
@@ -209,15 +227,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"typing": True,
},
),
- federation_auth_origin=b'farm',
+ federation_auth_origin=b"farm",
)
self.render(request)
self.assertEqual(channel.code, 200)
- self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -229,6 +249,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"send_federation": True})
def test_stopped_typing(self):
self.room_members = [U_APPLE, U_BANANA, U_ONION]
@@ -237,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
member = RoomMember(ROOM_ID, U_APPLE.to_string())
self.handler._member_typing_until[member] = 1002000
- self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
+ self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()}
self.assertEquals(self.event_source.get_current_key(), 0)
@@ -247,7 +268,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
@@ -268,7 +289,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -285,11 +308,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -303,10 +328,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.reactor.pump([16])
- self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 2)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -320,11 +347,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])])
+ self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9021e647fe..7b92bdbc47 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -47,11 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_local_profile_change_with_support_user(self):
support_user_id = "@support:test"
self.get_success(
- self.store.register(
- user_id=support_user_id,
- token="123",
- password_hash=None,
- user_type=UserTypes.SUPPORT,
+ self.store.register_user(
+ user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
@@ -60,24 +57,21 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
profile = self.get_success(self.store.get_user_in_directory(support_user_id))
self.assertTrue(profile is None)
- display_name = 'display_name'
+ display_name = "display_name"
- profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name)
- regular_user_id = '@regular:test'
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+ regular_user_id = "@regular:test"
self.get_success(
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
- self.assertTrue(profile['display_name'] == display_name)
+ self.assertTrue(profile["display_name"] == display_name)
def test_handle_user_deactivated_support_user(self):
s_user_id = "@support:test"
self.get_success(
- self.store.register(
- user_id=s_user_id,
- token="123",
- password_hash=None,
- user_type=UserTypes.SUPPORT,
+ self.store.register_user(
+ user_id=s_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
@@ -90,7 +84,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_user_deactivated_regular_user(self):
r_user_id = "@regular:test"
self.get_success(
- self.store.register(user_id=r_user_id, token="123", password_hash=None)
+ self.store.register_user(user_id=r_user_id, password_hash=None)
)
self.store.remove_from_user_dir = Mock()
self.get_success(self.handler.handle_user_deactivated(r_user_id))
@@ -120,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
public_users = self.get_users_in_public_rooms()
self.assertEqual(
- self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
)
self.assertEqual(public_users, [])
@@ -153,6 +147,98 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
+ def test_spam_checker(self):
+ """
+ A user which fails to the spam checks will not appear in search results.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ # We do not add users to the directory until they join a room.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Check we have populated the database correctly.
+ shares_private = self.get_users_who_share_private_rooms()
+ public_users = self.get_users_in_public_rooms()
+
+ self.assertEqual(
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ )
+ self.assertEqual(public_users, [])
+
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that does not filter any users.
+ spam_checker = self.hs.get_spam_checker()
+
+ class AllowAll(object):
+ def check_username_for_spam(self, user_profile):
+ # Allow all users.
+ return False
+
+ spam_checker.spam_checker = AllowAll()
+
+ # The results do not change:
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that filters all users.
+ class BlockAll(object):
+ def check_username_for_spam(self, user_profile):
+ # All users are spammy.
+ return True
+
+ spam_checker.spam_checker = BlockAll()
+
+ # User1 now gets no search results for any of the other users.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
+ def test_legacy_spam_checker(self):
+ """
+ A spam checker without the expected method should be ignored.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ # We do not add users to the directory until they join a room.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Check we have populated the database correctly.
+ shares_private = self.get_users_who_share_private_rooms()
+ public_users = self.get_users_in_public_rooms()
+
+ self.assertEqual(
+ self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+ )
+ self.assertEqual(public_users, [])
+
+ # Configure a spam checker.
+ spam_checker = self.hs.get_spam_checker()
+ # The spam checker doesn't need any methods, so create a bare object.
+ spam_checker.spam_checker = object()
+
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
def _compress_shared(self, shared):
"""
Compress a list of users who share rooms dicts to a list of tuples.
@@ -164,7 +250,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -175,7 +261,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store._simple_select_list(
+ self.store.db.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -187,10 +273,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store._all_done = False
+ self.store.db.updates._all_done = False
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -199,7 +285,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -209,7 +295,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -219,7 +305,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store._simple_insert(
+ self.store.db.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -261,19 +347,23 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
# User 1 and User 2 are in the same public room
- self.assertEqual(set(public_users), set([(u1, room), (u2, room)]))
+ self.assertEqual(set(public_users), {(u1, room), (u2, room)})
# User 1 and User 3 share private rooms
self.assertEqual(
self._compress_shared(shares_private),
- set([(u1, u3, private_room), (u3, u1, private_room)]),
+ {(u1, u3, private_room), (u3, u1, private_room)},
)
def test_initial_share_all_users(self):
@@ -296,15 +386,19 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._add_background_updates()
- while not self.get_success(self.store.has_completed_background_updates()):
- self.get_success(self.store.do_next_background_update(100), by=0.1)
+ while not self.get_success(
+ self.store.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db.updates.do_next_background_update(100), by=0.1
+ )
shares_private = self.get_users_who_share_private_rooms()
public_users = self.get_users_in_public_rooms()
# No users share rooms
self.assertEqual(public_users, [])
- self.assertEqual(self._compress_shared(shares_private), set([]))
+ self.assertEqual(self._compress_shared(shares_private), set())
# Despite not sharing a room, search_all_users means we get a search
# result.
|