diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 48b2d3d663..2a7044801a 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -60,7 +60,7 @@ class FilteringTestCase(unittest.TestCase):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
- {"event_fields": ["\\foo"]},
+ {"event_fields": [r"\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
@@ -109,6 +109,16 @@ class FilteringTestCase(unittest.TestCase):
"event_format": "client",
"event_fields": ["type", "content", "sender"],
},
+
+ # a single backslash should be permitted (though it is debatable whether
+ # it should be permitted before anything other than `.`, and what that
+ # actually means)
+ #
+ # (note that event_fields is implemented in
+ # synapse.events.utils.serialize_event, and so whether this actually works
+ # is tested elsewhere. We just want to check that it is allowed through the
+ # filter validation)
+ {"event_fields": [r"foo\.bar"]},
]
for filter in valid_filters:
try:
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index f88d28a19d..0c23068bcf 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
- matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
+ matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M)
self.assertEqual(1, len(matches))
self.assertEqual(matches[0], expected)
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
new file mode 100644
index 0000000000..f37a17d618
--- /dev/null
+++ b/tests/config/test_room_directory.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+
+from synapse.config.room_directory import RoomDirectoryConfig
+
+from tests import unittest
+
+
+class RoomDirectoryConfigTestCase(unittest.TestCase):
+ def test_alias_creation_acl(self):
+ config = yaml.load("""
+ alias_creation_rules:
+ - user_id: "*bob*"
+ alias: "*"
+ action: "deny"
+ - user_id: "*"
+ alias: "#unofficial_*"
+ action: "allow"
+ - user_id: "@foo*:example.com"
+ alias: "*"
+ action: "allow"
+ - user_id: "@gah:example.com"
+ alias: "#goo:example.com"
+ action: "allow"
+ """)
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
+ self.assertFalse(rd_config.is_alias_creation_allowed(
+ user_id="@bob:example.com",
+ alias="#test:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ alias="#unofficial_st:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@foobar:example.com",
+ alias="#test:example.com",
+ ))
+
+ self.assertTrue(rd_config.is_alias_creation_allowed(
+ user_id="@gah:example.com",
+ alias="#goo:example.com",
+ ))
+
+ self.assertFalse(rd_config.is_alias_creation_allowed(
+ user_id="@test:example.com",
+ alias="#test:example.com",
+ ))
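A minimal sketch of how glob-style alias creation rules like those above could be evaluated, assuming first-match-wins semantics and fnmatch-style globs; the helper below is a hypothetical stand-in for illustration, not the actual RoomDirectoryConfig implementation.

    from fnmatch import fnmatch

    def is_alias_creation_allowed(rules, user_id, alias):
        # First matching rule wins; both the user_id and alias globs must match.
        for rule in rules:
            if fnmatch(user_id, rule["user_id"]) and fnmatch(alias, rule["alias"]):
                return rule["action"] == "allow"
        # Deny when no rule matches, consistent with the final assertion above.
        return False

    rules = [
        {"user_id": "*bob*", "alias": "*", "action": "deny"},
        {"user_id": "*", "alias": "#unofficial_*", "action": "allow"},
    ]
    assert not is_alias_creation_allowed(rules, "@bob:example.com", "#test:example.com")
    assert is_alias_creation_allowed(rules, "@test:example.com", "#unofficial_st:example.com")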
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index ff217ca8b9..d0cc492deb 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase):
room_id="!foo:bar",
content={"key.with.dots": {}},
),
- ["content.key\.with\.dots"],
+ [r"content.key\.with\.dots"],
),
{"content": {"key.with.dots": {}}},
)
@@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase):
"nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
},
),
- ["content.nested\.dot\.key.leaf\.key"],
+ [r"content.nested\.dot\.key.leaf\.key"],
),
{"content": {"nested.dot.key": {"leaf.key": 42}}},
)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ec7355688b..8ae6556c0a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -18,7 +18,9 @@ from mock import Mock
from twisted.internet import defer
+from synapse.config.room_directory import RoomDirectoryConfig
from synapse.handlers.directory import DirectoryHandler
+from synapse.rest.client.v1 import directory, room
from synapse.types import RoomAlias
from tests import unittest
@@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase):
)
self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+
+
+class TestCreateAliasACL(unittest.HomeserverTestCase):
+ user_id = "@test:test"
+
+ servlets = [directory.register_servlets, room.register_servlets]
+
+    def prepare(self, reactor, clock, hs):
+ # We cheekily override the config to add custom alias creation rules
+ config = {}
+ config["alias_creation_rules"] = [
+ {
+ "user_id": "*",
+ "alias": "#unofficial_*",
+ "action": "allow",
+ }
+ ]
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
+ self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed
+
+ return hs
+
+ def test_denied(self):
+ room_id = self.helper.create_room_as(self.user_id)
+
+ request, channel = self.make_request(
+ "PUT",
+ b"directory/room/%23test%3Atest",
+ ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ )
+ self.render(request)
+ self.assertEquals(403, channel.code, channel.result)
+
+ def test_allowed(self):
+ room_id = self.helper.create_room_as(self.user_id)
+
+ request, channel = self.make_request(
+ "PUT",
+ b"directory/room/%23unofficial_test%3Atest",
+ ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 7b4ade3dfb..3e9a190727 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import ResourceLimitError
from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID, create_requester
+from synapse.types import RoomAlias, UserID, create_requester
from tests.utils import setup_test_homeserver
@@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase):
self.mock_captcha_client = Mock()
self.hs = yield setup_test_homeserver(
self.addCleanup,
- handlers=None,
- http_client=None,
expire_access_token=True,
- profile_handler=Mock(),
)
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
- self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.store = self.hs.get_datastore()
self.hs.config.max_mau_value = 50
self.lots_of_users = 100
self.small_number_of_users = 1
+ self.requester = create_requester("@requester:test")
+
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
- local_part = "someone"
- display_name = "someone"
- user_id = "@someone:test"
- requester = create_requester("@as:test")
+ frank = UserID.from_string("@frank:test")
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, frank.localpart, "Frankie"
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase):
token="jkv;g498752-43gj['eamb!-5",
password_hash=None,
)
- local_part = "frank"
- display_name = "Frank"
- user_id = "@frank:test"
- requester = create_requester("@as:test")
+ local_part = frank.localpart
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, local_part, None
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.handler.get_or_create_user("requester", 'a', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'a', "display_name")
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
@@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- yield self.handler.get_or_create_user("@user:server", 'c', "User")
+ yield self.handler.get_or_create_user(self.requester, 'c', "User")
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
@@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
@defer.inlineCallbacks
def test_register_mau_blocked(self):
@@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part")
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = yield directory_handler.get_association(room_alias)
+
+ self.assertTrue(room_id['room_id'] in rooms)
+ self.assertEqual(len(rooms), 1)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms_with_no_rooms(self):
+ self.hs.config.auto_join_rooms = []
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_room_is_another_domain(self):
+ self.hs.config.auto_join_rooms = ["#room:another"]
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_auto_create_is_false(self):
+ self.hs.config.autocreate_auto_join_rooms = False
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/scripts/__init__.py
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
new file mode 100644
index 0000000000..6f56893f5e
--- /dev/null
+++ b/tests/scripts/test_new_matrix_user.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from synapse._scripts.register_new_matrix_user import request_registration
+
+from tests.unittest import TestCase
+
+
+class RegisterTestCase(TestCase):
+ def test_success(self):
+ """
+ The script will fetch a nonce, and then generate a MAC with it, and then
+        The script will fetch a nonce, generate a MAC with it, and then POST
+        that MAC.
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 200
+ r.json = lambda: {"nonce": "a"}
+ return r
+
+ def post(url, json=None, verify=None):
+ # Make sure we are sent the correct info
+ self.assertEqual(json["username"], "user")
+ self.assertEqual(json["password"], "pass")
+ self.assertEqual(json["nonce"], "a")
+ # We want a 40-char hex MAC
+ self.assertEqual(len(json["mac"]), 40)
+
+ r = Mock()
+ r.status_code = 200
+ return r
+
+ requests = Mock()
+ requests.get = get
+ requests.post = post
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # We should get the success message making sure everything is OK.
+ self.assertIn("Success!", out)
+
+ # sys.exit shouldn't have been called.
+ self.assertEqual(err_code, [])
+
+ def test_failure_nonce(self):
+ """
+        If the script fails to fetch a nonce, it reports an error and quits.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 404
+ r.reason = "Not Found"
+ r.json = lambda: {"not": "error"}
+ return r
+
+ requests = Mock()
+ requests.get = get
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # Exit was called
+ self.assertEqual(err_code, [1])
+
+ # We got an error message
+ self.assertIn("ERROR! Received 404 Not Found", out)
+ self.assertNotIn("Success!", out)
+
+ def test_failure_post(self):
+ """
+ The script will fetch a nonce, and then if the final POST fails, will
+ report an error and quit.
+ """
+
+ def get(url, verify=None):
+ r = Mock()
+ r.status_code = 200
+ r.json = lambda: {"nonce": "a"}
+ return r
+
+ def post(url, json=None, verify=None):
+ # Make sure we are sent the correct info
+ self.assertEqual(json["username"], "user")
+ self.assertEqual(json["password"], "pass")
+ self.assertEqual(json["nonce"], "a")
+ # We want a 40-char hex MAC
+ self.assertEqual(len(json["mac"]), 40)
+
+ r = Mock()
+ # Then 500 because we're jerks
+ r.status_code = 500
+ r.reason = "Broken"
+ return r
+
+ requests = Mock()
+ requests.get = get
+ requests.post = post
+
+ # The fake stdout will be written here
+ out = []
+ err_code = []
+
+ request_registration(
+ "user",
+ "pass",
+ "matrix.org",
+ "shared",
+ admin=False,
+ requests=requests,
+ _print=out.append,
+ exit=err_code.append,
+ )
+
+ # Exit was called
+ self.assertEqual(err_code, [1])
+
+ # We got an error message
+ self.assertIn("ERROR! Received 500 Broken", out)
+ self.assertNotIn("Success!", out)
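A note on the 40-character MAC these tests check for: the shared-secret registration MAC is, as far as I understand it, an HMAC-SHA1 hex digest, which is always 40 hex characters. A minimal sketch, assuming the nonce/user/password/admin fields are fed to the HMAC separated by NUL bytes (the field layout is an assumption here, not taken from these tests):

    import hashlib
    import hmac

    def generate_mac(shared_secret, nonce, user, password, admin=False):
        mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)
        mac.update(nonce.encode("utf8"))
        mac.update(b"\x00")
        mac.update(user.encode("utf8"))
        mac.update(b"\x00")
        mac.update(password.encode("utf8"))
        mac.update(b"\x00")
        mac.update(b"admin" if admin else b"notadmin")
        return mac.hexdigest()

    # SHA1 digests are 20 bytes, so the hex digest is 40 characters long.
    assert len(generate_mac("shared", "a", "user", "pass")) == 40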
diff --git a/tests/state/__init__.py b/tests/state/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/state/__init__.py
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
new file mode 100644
index 0000000000..efd85ebe6c
--- /dev/null
+++ b/tests/state/test_v2.py
@@ -0,0 +1,663 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+from six.moves import zip
+
+import attr
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.event_auth import auth_types_for_event
+from synapse.events import FrozenEvent
+from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.types import EventID
+
+from tests import unittest
+
+ALICE = "@alice:example.com"
+BOB = "@bob:example.com"
+CHARLIE = "@charlie:example.com"
+EVELYN = "@evelyn:example.com"
+ZARA = "@zara:example.com"
+
+ROOM_ID = "!test:example.com"
+
+MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
+MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
+
+
+ORIGIN_SERVER_TS = 0
+
+
+class FakeEvent(object):
+ """A fake event we use as a convenience.
+
+    NOTE: As a further convenience, we use "node_ids" rather than event_ids to
+ refer to events. The event_id has node_id as localpart and example.com
+ as domain.
+ """
+ def __init__(self, id, sender, type, state_key, content):
+ self.node_id = id
+ self.event_id = EventID(id, "example.com").to_string()
+ self.sender = sender
+ self.type = type
+ self.state_key = state_key
+ self.content = content
+
+ def to_event(self, auth_events, prev_events):
+ """Given the auth_events and prev_events, convert to a Frozen Event
+
+ Args:
+ auth_events (list[str]): list of event_ids
+ prev_events (list[str]): list of event_ids
+
+ Returns:
+ FrozenEvent
+ """
+ global ORIGIN_SERVER_TS
+
+ ts = ORIGIN_SERVER_TS
+ ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1
+
+ event_dict = {
+ "auth_events": [(a, {}) for a in auth_events],
+ "prev_events": [(p, {}) for p in prev_events],
+ "event_id": self.node_id,
+ "sender": self.sender,
+ "type": self.type,
+ "content": self.content,
+ "origin_server_ts": ts,
+ "room_id": ROOM_ID,
+ }
+
+ if self.state_key is not None:
+ event_dict["state_key"] = self.state_key
+
+ return FrozenEvent(event_dict)
+
+
+# All graphs start with this set of events
+INITIAL_EVENTS = [
+ FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ),
+ FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IPOWER",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100}},
+ ),
+ FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ),
+ FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMZ",
+ sender=ZARA,
+ type=EventTypes.Member,
+ state_key=ZARA,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="START",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="END",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+]
+
+INITIAL_EDGES = [
+ "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
+]
+
+
+class StateTestCase(unittest.TestCase):
+ def test_ban_vs_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="MA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content={"membership": Membership.JOIN},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "MA", "PA", "START"],
+ ["END", "PB", "PA"],
+ ]
+
+ expected_state_ids = ["PA", "MA", "MB"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_join_rule_evasion(self):
+ events = [
+ FakeEvent(
+ id="JR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rules": JoinRules.PRIVATE},
+ ),
+ FakeEvent(
+ id="ME",
+ sender=EVELYN,
+ type=EventTypes.Member,
+ state_key=EVELYN,
+ content={"membership": Membership.JOIN},
+ ),
+ ]
+
+ edges = [
+ ["END", "JR", "START"],
+ ["END", "ME", "START"],
+ ]
+
+ expected_state_ids = ["JR"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_offtopic_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PC",
+ sender=CHARLIE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 0,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "PC", "PB", "PA", "START"],
+ ["END", "PA"],
+ ]
+
+ expected_state_ids = ["PC"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_basic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["PA2", "T2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_reset(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "T2", "PA", "T1", "START"],
+ ["END", "T1"],
+ ]
+
+ expected_state_ids = ["T1", "MB", "PA"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MZ1",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="T4",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "MZ1", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["T4", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def do_check(self, events, edges, expected_state_ids):
+        """Takes a list of events and edges, calculates the state of the
+        graph at END, and asserts that it matches `expected_state_ids`.
+
+ Args:
+ events (list[FakeEvent])
+ edges (list[list[str]]): A list of chains of event edges, e.g.
+                `[[A, B, C]]` denotes edges A->B and B->C.
+            expected_state_ids (list[str]): The expected state at END (excluding
+ the keys that haven't changed since START).
+ """
+ # We want to sort the events into topological order for processing.
+ graph = {}
+
+ # node_id -> FakeEvent
+ fake_event_map = {}
+
+ for ev in itertools.chain(INITIAL_EVENTS, events):
+ graph[ev.node_id] = set()
+ fake_event_map[ev.node_id] = ev
+
+ for a, b in pairwise(INITIAL_EDGES):
+ graph[a].add(b)
+
+ for edge_list in edges:
+ for a, b in pairwise(edge_list):
+ graph[a].add(b)
+
+ # event_id -> FrozenEvent
+ event_map = {}
+ # node_id -> state
+ state_at_event = {}
+
+ # We copy the map as the sort consumes the graph
+ graph_copy = {k: set(v) for k, v in graph.items()}
+
+ for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
+ fake_event = fake_event_map[node_id]
+ event_id = fake_event.event_id
+
+ prev_events = list(graph[node_id])
+
+ if len(prev_events) == 0:
+ state_before = {}
+ elif len(prev_events) == 1:
+ state_before = dict(state_at_event[prev_events[0]])
+ else:
+ state_d = resolve_events_with_store(
+ [state_at_event[n] for n in prev_events],
+ event_map=event_map,
+ state_res_store=TestStateResolutionStore(event_map),
+ )
+
+ self.assertTrue(state_d.called)
+ state_before = state_d.result
+
+ state_after = dict(state_before)
+ if fake_event.state_key is not None:
+ state_after[(fake_event.type, fake_event.state_key)] = event_id
+
+ auth_types = set(auth_types_for_event(fake_event))
+
+ auth_events = []
+ for key in auth_types:
+ if key in state_before:
+ auth_events.append(state_before[key])
+
+ event = fake_event.to_event(auth_events, prev_events)
+
+ state_at_event[node_id] = state_after
+ event_map[event_id] = event
+
+ expected_state = {}
+ for node_id in expected_state_ids:
+ # expected_state_ids are node IDs rather than event IDs,
+ # so we have to convert
+ event_id = EventID(node_id, "example.com").to_string()
+ event = event_map[event_id]
+
+ key = (event.type, event.state_key)
+
+ expected_state[key] = event_id
+
+ start_state = state_at_event["START"]
+ end_state = {
+ key: value
+ for key, value in state_at_event["END"].items()
+ if key in expected_state or start_state.get(key) != value
+ }
+
+ self.assertEqual(expected_state, end_state)
+
+
+class LexicographicalTestCase(unittest.TestCase):
+ def test_simple(self):
+ graph = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
+
+ res = list(lexicographical_topological_sort(graph, key=lambda x: x))
+
+ self.assertEqual(["o", "l", "n", "m", "p"], res)
+
+
+def pairwise(iterable):
+ "s -> (s0,s1), (s1,s2), (s2, s3), ..."
+ a, b = itertools.tee(iterable)
+ next(b, None)
+ return zip(a, b)
+
+
+@attr.s
+class TestStateResolutionStore(object):
+ event_map = attr.ib()
+
+ def get_events(self, event_ids, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ """
+
+ return {
+ eid: self.event_map[eid]
+ for eid in event_ids
+ if eid in self.event_map
+ }
+
+ def get_auth_chain(self, event_ids):
+ """Gets the full auth chain for a set of events (including rejected
+ events).
+
+ Includes the given event IDs in the result.
+
+ Note that:
+ 1. All events must be state events.
+            2. For v1 rooms this may not return the full auth chain in the
+               presence of rejected events.
+
+ Args:
+ event_ids (list): The event IDs of the events to fetch the auth
+ chain for. Must be state events.
+
+ Returns:
+ Deferred[list[str]]: List of event IDs of the auth chain.
+ """
+
+ # Simple DFS for auth chain
+ result = set()
+ stack = list(event_ids)
+ while stack:
+ event_id = stack.pop()
+ if event_id in result:
+ continue
+
+ result.add(event_id)
+
+ event = self.event_map[event_id]
+ for aid, _ in event.auth_events:
+ stack.append(aid)
+
+ return list(result)
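For readers tracing do_check: each edge chain lists events newest-first, and pairwise() turns it into the individual prev-event edges that are added to the graph. A small usage sketch of the pairwise helper defined above:

    import itertools

    def pairwise(iterable):
        "s -> (s0,s1), (s1,s2), (s2,s3), ..."
        a, b = itertools.tee(iterable)
        next(b, None)
        return zip(a, b)

    chain = ["END", "MB", "MA", "PA", "START"]
    assert list(pairwise(chain)) == [
        ("END", "MB"), ("MB", "MA"), ("MA", "PA"), ("PA", "START"),
    ]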
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 686f12a0dc..832e379a83 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -52,7 +52,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
now = int(self.hs.get_clock().time_msec())
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
- self.store.initialise_reserved_users(threepids)
+
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
self.pump()
active_count = self.store.get_monthly_active_count()
@@ -199,7 +202,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
{'medium': 'email', 'address': user2_email},
]
self.hs.config.mau_limits_reserved_threepids = threepids
- self.store.initialise_reserved_users(threepids)
+ self.store.runInteraction(
+ "initialise", self.store._initialise_reserved_users, threepids
+ )
+
self.pump()
count = self.store.get_registered_reserved_users_count()
self.assertEquals(self.get_success(count), 0)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b9c5b39d59..086a39d834 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
import tests.unittest
@@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we get the full state as of the final event
state = yield self.store.get_state_for_event(
- e5.event_id, None, filtered_types=None
+ e5.event_id,
)
self.assertIsNotNone(e4)
@@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, '')], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Name, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event(
- e5.event_id, [(EventTypes.Member, None)], filtered_types=None
+ e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
self.assertStateMapEqual(
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
)
- # check we can use filtered_types to grab a specific room member
- # without filtering out the other event types
+ # check we can grab a specific room member without filtering out the
+ # other event types
state = yield self.store.get_state_for_event(
e5.event_id,
- [(EventTypes.Member, self.u_alice.to_string())],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {self.u_alice.to_string()}},
+ include_others=True,
+ )
)
self.assertStateMapEqual(
@@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
state,
)
- # check that types=[], filtered_types=[EventTypes.Member]
- # doesn't return all members
+ # check that we can grab everything except members
state = yield self.store.get_state_for_event(
- e5.event_id, [], filtered_types=[EventTypes.Member]
+ e5.event_id, state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertStateMapEqual(
@@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
#######################################################
- # _get_some_state_from_cache tests against a full cache
+ # _get_state_for_group_using_cache tests against a full cache
#######################################################
room_id = self.room.to_string()
group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
group = list(group_ids.keys())[0]
- # test _get_some_state_from_cache correctly filters out members with types=[]
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({}, state_dict)
- # test _get_some_state_from_cache correctly filters in members with wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=None,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
@@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase):
############################################
# test that things work with a partial cache
- # test _get_some_state_from_cache correctly filters out members with types=[]
+ # test _get_state_for_group_using_cache correctly filters out members
+ # with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
- self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+ self.store._state_group_cache, group,
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: set()},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({}, state_dict)
- # test _get_some_state_from_cache correctly filters in members wildcard types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # wildcard types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, None)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: None},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
@@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- # test _get_some_state_from_cache correctly filters in members with specific types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=[EventTypes.Member],
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=True,
+ ),
)
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
- # test _get_some_state_from_cache correctly filters in members with specific types
- # and no filtered_types
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ # test _get_state_for_group_using_cache correctly filters in members
+ # with specific types
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=None,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
self.store._state_group_members_cache,
group,
- [(EventTypes.Member, e5.state_key)],
- filtered_types=None,
+ state_filter=StateFilter(
+ types={EventTypes.Member: {e5.state_key}},
+ include_others=False,
+ ),
)
self.assertEqual(is_all, True)
diff --git a/tests/utils.py b/tests/utils.py
index dd347a0c59..565bb60d08 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -124,6 +124,7 @@ def default_config(name):
config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None
config.media_storage_providers = []
+ config.autocreate_auto_join_rooms = True
config.auto_join_rooms = []
config.limit_usage_by_mau = False
config.hs_disabled = False
|