diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/config/test_room_directory.py | 67 | ||||
-rw-r--r-- | tests/handlers/test_directory.py | 48 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 75 | ||||
-rw-r--r-- | tests/push/__init__.py | 0 | ||||
-rw-r--r-- | tests/push/test_email.py | 148 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 106 | ||||
-rw-r--r-- | tests/scripts/__init__.py | 0 | ||||
-rw-r--r-- | tests/scripts/test_new_matrix_user.py | 160 | ||||
-rw-r--r-- | tests/server.py | 4 | ||||
-rw-r--r-- | tests/server_notices/test_resource_limits_server_notices.py | 10 | ||||
-rw-r--r-- | tests/storage/test_monthly_active_users.py | 10 | ||||
-rw-r--r-- | tests/storage/test_state.py | 175 | ||||
-rw-r--r-- | tests/test_mau.py | 2 | ||||
-rw-r--r-- | tests/unittest.py | 9 | ||||
-rw-r--r-- | tests/utils.py | 1 |
15 files changed, 716 insertions, 99 deletions
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py new file mode 100644 index 0000000000..f37a17d618 --- /dev/null +++ b/tests/config/test_room_directory.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml + +from synapse.config.room_directory import RoomDirectoryConfig + +from tests import unittest + + +class RoomDirectoryConfigTestCase(unittest.TestCase): + def test_alias_creation_acl(self): + config = yaml.load(""" + alias_creation_rules: + - user_id: "*bob*" + alias: "*" + action: "deny" + - user_id: "*" + alias: "#unofficial_*" + action: "allow" + - user_id: "@foo*:example.com" + alias: "*" + action: "allow" + - user_id: "@gah:example.com" + alias: "#goo:example.com" + action: "allow" + """) + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#unofficial_st:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", + alias="#goo:example.com", + )) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#test:example.com", + )) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ec7355688b..8ae6556c0a 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,7 +18,9 @@ from mock import Mock from twisted.internet import defer +from synapse.config.room_directory import RoomDirectoryConfig from synapse.handlers.directory import DirectoryHandler +from synapse.rest.client.v1 import directory, room from synapse.types import RoomAlias from tests import unittest @@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase): ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestCreateAliasACL(unittest.HomeserverTestCase): + user_id = "@test:test" + + servlets = [directory.register_servlets, room.register_servlets] + + def prepare(self, hs, reactor, clock): + # We cheekily override the config to add custom alias creation rules + config = {} + config["alias_creation_rules"] = [ + { + "user_id": "*", + "alias": "#unofficial_*", + "action": "allow", + } + ] + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed + + return hs + + def test_denied(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + + def test_allowed(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23unofficial_test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 7b4ade3dfb..3e9a190727 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import ResourceLimitError from synapse.handlers.register import RegistrationHandler -from synapse.types import UserID, create_requester +from synapse.types import RoomAlias, UserID, create_requester from tests.utils import setup_test_homeserver @@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase): self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( self.addCleanup, - handlers=None, - http_client=None, expire_access_token=True, - profile_handler=Mock(), ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.store = self.hs.get_datastore() self.hs.config.max_mau_value = 50 self.lots_of_users = 100 self.small_number_of_users = 1 + self.requester = create_requester("@requester:test") + @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - local_part = "someone" - display_name = "someone" - user_id = "@someone:test" - requester = create_requester("@as:test") + frank = UserID.from_string("@frank:test") + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, frank.localpart, "Frankie" ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase): token="jkv;g498752-43gj['eamb!-5", password_hash=None, ) - local_part = "frank" - display_name = "Frank" - user_id = "@frank:test" - requester = create_requester("@as:test") + local_part = frank.localpart + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, local_part, None ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase): def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user("requester", 'a', "display_name") + yield self.handler.get_or_create_user(self.requester, 'a', "display_name") @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): @@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@user:server", 'c', "User") + yield self.handler.get_or_create_user(self.requester, 'c', "User") @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): @@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.lots_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") @defer.inlineCallbacks def test_register_mau_blocked(self): @@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = yield directory_handler.get_association(room_alias) + + self.assertTrue(room_id['room_id'] in rooms) + self.assertEqual(len(rooms), 1) + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms_with_no_rooms(self): + self.hs.config.auto_join_rooms = [] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_room_is_another_domain(self): + self.hs.config.auto_join_rooms = ["#room:another"] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_auto_create_is_false(self): + self.hs.config.autocreate_auto_join_rooms = False + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) diff --git a/tests/push/__init__.py b/tests/push/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/push/__init__.py diff --git a/tests/push/test_email.py b/tests/push/test_email.py new file mode 100644 index 0000000000..50ee6910d1 --- /dev/null +++ b/tests/push/test_email.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pkg_resources + +from twisted.internet.defer import Deferred + +from synapse.rest.client.v1 import admin, login, room + +from tests.unittest import HomeserverTestCase + +try: + from synapse.push.mailer import load_jinja2_templates +except Exception: + load_jinja2_templates = None + + +class EmailPusherTests(HomeserverTestCase): + + skip = "No Jinja installed" if not load_jinja2_templates else None + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + config = self.default_config() + config.email_enable_notifs = True + config.start_pushers = True + + config.email_template_dir = os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ) + config.email_notif_template_html = "notif_mail.html" + config.email_notif_template_text = "notif_mail.txt" + config.email_smtp_host = "127.0.0.1" + config.email_smtp_port = 20 + config.require_transport_security = False + config.email_smtp_user = None + config.email_app_name = "Matrix" + config.email_notif_from = "test@example.com" + + hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + + return hs + + def test_sends_email(self): + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name="a@example.com", + pushkey="a@example.com", + lang=None, + data={}, + ) + ) + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Get the stream ordering before it gets sent + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + last_stream_ordering = pushers[0]["last_stream_ordering"] + + # Advance time a bit, so the pusher will register something has happened + self.pump(100) + + # It hasn't succeeded yet, so the stream ordering shouldn't have moved + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + + # One email was attempted to be sent + self.assertEqual(len(self.email_attempts), 1) + + # Make the email succeed + self.email_attempts[0][0].callback(True) + self.pump() + + # One email was attempted to be sent + self.assertEqual(len(self.email_attempts), 1) + + # The stream ordering has increased + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 359f7777ff..a824be9a62 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -23,7 +23,7 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer from synapse.api.constants import Membership -from synapse.rest.client.v1 import room +from synapse.rest.client.v1 import admin, login, room from tests import unittest @@ -799,3 +799,107 @@ class RoomMessageListTestCase(RoomBase): self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/scripts/__init__.py diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py new file mode 100644 index 0000000000..6f56893f5e --- /dev/null +++ b/tests/scripts/test_new_matrix_user.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from synapse._scripts.register_new_matrix_user import request_registration + +from tests.unittest import TestCase + + +class RegisterTestCase(TestCase): + def test_success(self): + """ + The script will fetch a nonce, and then generate a MAC with it, and then + post that MAC. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 200 + r.json = lambda: {"nonce": "a"} + return r + + def post(url, json=None, verify=None): + # Make sure we are sent the correct info + self.assertEqual(json["username"], "user") + self.assertEqual(json["password"], "pass") + self.assertEqual(json["nonce"], "a") + # We want a 40-char hex MAC + self.assertEqual(len(json["mac"]), 40) + + r = Mock() + r.status_code = 200 + return r + + requests = Mock() + requests.get = get + requests.post = post + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # We should get the success message making sure everything is OK. + self.assertIn("Success!", out) + + # sys.exit shouldn't have been called. + self.assertEqual(err_code, []) + + def test_failure_nonce(self): + """ + If the script fails to fetch a nonce, it throws an error and quits. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 404 + r.reason = "Not Found" + r.json = lambda: {"not": "error"} + return r + + requests = Mock() + requests.get = get + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # Exit was called + self.assertEqual(err_code, [1]) + + # We got an error message + self.assertIn("ERROR! Received 404 Not Found", out) + self.assertNotIn("Success!", out) + + def test_failure_post(self): + """ + The script will fetch a nonce, and then if the final POST fails, will + report an error and quit. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 200 + r.json = lambda: {"nonce": "a"} + return r + + def post(url, json=None, verify=None): + # Make sure we are sent the correct info + self.assertEqual(json["username"], "user") + self.assertEqual(json["password"], "pass") + self.assertEqual(json["nonce"], "a") + # We want a 40-char hex MAC + self.assertEqual(len(json["mac"]), 40) + + r = Mock() + # Then 500 because we're jerks + r.status_code = 500 + r.reason = "Broken" + return r + + requests = Mock() + requests.get = get + requests.post = post + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # Exit was called + self.assertEqual(err_code, [1]) + + # We got an error message + self.assertIn("ERROR! Received 500 Broken", out) + self.assertNotIn("Success!", out) diff --git a/tests/server.py b/tests/server.py index 7bee58dff1..819c854448 100644 --- a/tests/server.py +++ b/tests/server.py @@ -125,7 +125,9 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe req.content = BytesIO(content) if access_token: - req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) + req.requestHeaders.addRawHeader( + b"Authorization", b"Bearer " + access_token.encode('ascii') + ) if content: req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 4701eedd45..b1551df7ca 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -4,7 +4,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError -from synapse.handlers.auth import AuthHandler from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) @@ -13,17 +12,10 @@ from tests import unittest from tests.utils import setup_test_homeserver -class AuthHandlers(object): - def __init__(self, hs): - self.auth_handler = AuthHandler(hs) - - class TestResourceLimitsServerNotices(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) - self.hs.handlers = AuthHandlers(self.hs) - self.auth_handler = self.hs.handlers.auth_handler + self.hs = yield setup_test_homeserver(self.addCleanup) self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 686f12a0dc..832e379a83 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -52,7 +52,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.initialise_reserved_users(threepids) + + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) self.pump() active_count = self.store.get_monthly_active_count() @@ -199,7 +202,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): {'medium': 'email', 'address': user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.initialise_reserved_users(threepids) + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + self.pump() count = self.store.get_registered_reserved_users_count() self.assertEquals(self.get_success(count), 0) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index b9c5b39d59..086a39d834 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID import tests.unittest @@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we get the full state as of the final event state = yield self.store.get_state_for_event( - e5.event_id, None, filtered_types=None + e5.event_id, ) self.assertIsNotNone(e4) @@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, '')], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Member, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) self.assertStateMapEqual( {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filtered_types to grab a specific room member - # without filtering out the other event types + # check we can grab a specific room member without filtering out the + # other event types state = yield self.store.get_state_for_event( e5.event_id, - [(EventTypes.Member, self.u_alice.to_string())], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ) ) self.assertStateMapEqual( @@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase): state, ) - # check that types=[], filtered_types=[EventTypes.Member] - # doesn't return all members + # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, [], filtered_types=[EventTypes.Member] + e5.event_id, state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertStateMapEqual( @@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ####################################################### - # _get_some_state_from_cache tests against a full cache + # _get_state_for_group_using_cache tests against a full cache ####################################################### room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group = list(group_ids.keys())[0] - # test _get_some_state_from_cache correctly filters out members with types=[] - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members with wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) @@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase): ############################################ # test that things work with a partial cache - # test _get_some_state_from_cache correctly filters out members with types=[] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, None)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=None, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) diff --git a/tests/test_mau.py b/tests/test_mau.py index bdbacb8448..5d387851c5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -207,7 +207,7 @@ class TestMauLimit(unittest.TestCase): def do_sync_for_user(self, token): request, channel = make_request( - "GET", "/sync", access_token=token.encode('ascii') + "GET", "/sync", access_token=token ) render(request, self.resource, self.reactor) diff --git a/tests/unittest.py b/tests/unittest.py index a59291cc60..4d40bdb6a5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -146,6 +146,13 @@ def DEBUG(target): return target +def INFO(target): + """A decorator to set the .loglevel attribute to logging.INFO. + Can apply to either a TestCase or an individual test method.""" + target.loglevel = logging.INFO + return target + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -373,5 +380,5 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 200) - access_token = channel.json_body["access_token"].encode('ascii') + access_token = channel.json_body["access_token"] return access_token diff --git a/tests/utils.py b/tests/utils.py index dd347a0c59..565bb60d08 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -124,6 +124,7 @@ def default_config(name): config.user_consent_server_notice_content = None config.block_events_without_consent_error = None config.media_storage_providers = [] + config.autocreate_auto_join_rooms = True config.auto_join_rooms = [] config.limit_usage_by_mau = False config.hs_disabled = False |