diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 48b2d3d663..2a7044801a 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -60,7 +60,7 @@ class FilteringTestCase(unittest.TestCase):
invalid_filters = [
{"boom": {}},
{"account_data": "Hello World"},
- {"event_fields": ["\\foo"]},
+ {"event_fields": [r"\\foo"]},
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}},
@@ -109,6 +109,16 @@ class FilteringTestCase(unittest.TestCase):
"event_format": "client",
"event_fields": ["type", "content", "sender"],
},
+
+ # a single backslash should be permitted (though it is debatable whether
+ # it should be permitted before anything other than `.`, and what that
+ # actually means)
+ #
+ # (note that event_fields is implemented in
+ # synapse.events.utils.serialize_event, and so whether this actually works
+ # is tested elsewhere. We just want to check that it is allowed through the
+ # filter validation)
+ {"event_fields": [r"foo\.bar"]},
]
for filter in valid_filters:
try:
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index f88d28a19d..0c23068bcf 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
- matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
+ matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M)
self.assertEqual(1, len(matches))
self.assertEqual(matches[0], expected)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index ff217ca8b9..d0cc492deb 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase):
room_id="!foo:bar",
content={"key.with.dots": {}},
),
- ["content.key\.with\.dots"],
+ [r"content.key\.with\.dots"],
),
{"content": {"key.with.dots": {}}},
)
@@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase):
"nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
},
),
- ["content.nested\.dot\.key.leaf\.key"],
+ [r"content.nested\.dot\.key.leaf\.key"],
),
{"content": {"nested.dot.key": {"leaf.key": 42}}},
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 7b4ade3dfb..3e9a190727 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import ResourceLimitError
from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID, create_requester
+from synapse.types import RoomAlias, UserID, create_requester
from tests.utils import setup_test_homeserver
@@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase):
self.mock_captcha_client = Mock()
self.hs = yield setup_test_homeserver(
self.addCleanup,
- handlers=None,
- http_client=None,
expire_access_token=True,
- profile_handler=Mock(),
)
self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
- self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.store = self.hs.get_datastore()
self.hs.config.max_mau_value = 50
self.lots_of_users = 100
self.small_number_of_users = 1
+ self.requester = create_requester("@requester:test")
+
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
- local_part = "someone"
- display_name = "someone"
- user_id = "@someone:test"
- requester = create_requester("@as:test")
+ frank = UserID.from_string("@frank:test")
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, frank.localpart, "Frankie"
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase):
token="jkv;g498752-43gj['eamb!-5",
password_hash=None,
)
- local_part = "frank"
- display_name = "Frank"
- user_id = "@frank:test"
- requester = create_requester("@as:test")
+ local_part = frank.localpart
+ user_id = frank.to_string()
+ requester = create_requester(user_id)
result_user_id, result_token = yield self.handler.get_or_create_user(
- requester, local_part, display_name
+ requester, local_part, None
)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')
@@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase):
def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.handler.get_or_create_user("requester", 'a', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'a', "display_name")
@defer.inlineCallbacks
def test_get_or_create_user_mau_not_blocked(self):
@@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1)
)
# Ensure does not throw exception
- yield self.handler.get_or_create_user("@user:server", 'c', "User")
+ yield self.handler.get_or_create_user(self.requester, 'c', "User")
@defer.inlineCallbacks
def test_get_or_create_user_mau_blocked(self):
@@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase):
return_value=defer.succeed(self.lots_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.handler.get_or_create_user("requester", 'b', "display_name")
+ yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
@defer.inlineCallbacks
def test_register_mau_blocked(self):
@@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
yield self.handler.register_saml2(localpart="local_part")
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms(self):
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+
+ directory_handler = self.hs.get_handlers().directory_handler
+ room_alias = RoomAlias.from_string(room_alias_str)
+ room_id = yield directory_handler.get_association(room_alias)
+
+ self.assertTrue(room_id['room_id'] in rooms)
+ self.assertEqual(len(rooms), 1)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_rooms_with_no_rooms(self):
+ self.hs.config.auto_join_rooms = []
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_room_is_another_domain(self):
+ self.hs.config.auto_join_rooms = ["#room:another"]
+ frank = UserID.from_string("@frank:test")
+ res = yield self.handler.register(frank.localpart)
+ self.assertEqual(res[0], frank.to_string())
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
+
+ @defer.inlineCallbacks
+ def test_auto_create_auto_join_where_auto_create_is_false(self):
+ self.hs.config.autocreate_auto_join_rooms = False
+ room_alias_str = "#room:test"
+ self.hs.config.auto_join_rooms = [room_alias_str]
+ res = yield self.handler.register(localpart='jeff')
+ rooms = yield self.store.get_rooms_for_user(res[0])
+ self.assertEqual(len(rooms), 0)
diff --git a/tests/state/__init__.py b/tests/state/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/state/__init__.py
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
new file mode 100644
index 0000000000..efd85ebe6c
--- /dev/null
+++ b/tests/state/test_v2.py
@@ -0,0 +1,663 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+from six.moves import zip
+
+import attr
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.event_auth import auth_types_for_event
+from synapse.events import FrozenEvent
+from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.types import EventID
+
+from tests import unittest
+
+ALICE = "@alice:example.com"
+BOB = "@bob:example.com"
+CHARLIE = "@charlie:example.com"
+EVELYN = "@evelyn:example.com"
+ZARA = "@zara:example.com"
+
+ROOM_ID = "!test:example.com"
+
+MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
+MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
+
+
+ORIGIN_SERVER_TS = 0
+
+
+class FakeEvent(object):
+ """A fake event we use as a convenience.
+
+ NOTE: Again as a convenience we use "node_ids" rather than event_ids to
+ refer to events. The event_id has node_id as localpart and example.com
+ as domain.
+ """
+ def __init__(self, id, sender, type, state_key, content):
+ self.node_id = id
+ self.event_id = EventID(id, "example.com").to_string()
+ self.sender = sender
+ self.type = type
+ self.state_key = state_key
+ self.content = content
+
+ def to_event(self, auth_events, prev_events):
+ """Given the auth_events and prev_events, convert to a Frozen Event
+
+ Args:
+ auth_events (list[str]): list of event_ids
+ prev_events (list[str]): list of event_ids
+
+ Returns:
+ FrozenEvent
+ """
+ global ORIGIN_SERVER_TS
+
+ ts = ORIGIN_SERVER_TS
+ ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1
+
+ event_dict = {
+ "auth_events": [(a, {}) for a in auth_events],
+ "prev_events": [(p, {}) for p in prev_events],
+ "event_id": self.node_id,
+ "sender": self.sender,
+ "type": self.type,
+ "content": self.content,
+ "origin_server_ts": ts,
+ "room_id": ROOM_ID,
+ }
+
+ if self.state_key is not None:
+ event_dict["state_key"] = self.state_key
+
+ return FrozenEvent(event_dict)
+
+
+# All graphs start with this set of events
+INITIAL_EVENTS = [
+ FakeEvent(
+ id="CREATE",
+ sender=ALICE,
+ type=EventTypes.Create,
+ state_key="",
+ content={"creator": ALICE},
+ ),
+ FakeEvent(
+ id="IMA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IPOWER",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100}},
+ ),
+ FakeEvent(
+ id="IJR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.PUBLIC},
+ ),
+ FakeEvent(
+ id="IMB",
+ sender=BOB,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMC",
+ sender=CHARLIE,
+ type=EventTypes.Member,
+ state_key=CHARLIE,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="IMZ",
+ sender=ZARA,
+ type=EventTypes.Member,
+ state_key=ZARA,
+ content=MEMBERSHIP_CONTENT_JOIN,
+ ),
+ FakeEvent(
+ id="START",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="END",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+]
+
+INITIAL_EDGES = [
+ "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
+]
+
+
+class StateTestCase(unittest.TestCase):
+ def test_ban_vs_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="MA",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=ALICE,
+ content={"membership": Membership.JOIN},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "MA", "PA", "START"],
+ ["END", "PB", "PA"],
+ ]
+
+ expected_state_ids = ["PA", "MA", "MB"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_join_rule_evasion(self):
+ events = [
+ FakeEvent(
+ id="JR",
+ sender=ALICE,
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rules": JoinRules.PRIVATE},
+ ),
+ FakeEvent(
+ id="ME",
+ sender=EVELYN,
+ type=EventTypes.Member,
+ state_key=EVELYN,
+ content={"membership": Membership.JOIN},
+ ),
+ ]
+
+ edges = [
+ ["END", "JR", "START"],
+ ["END", "ME", "START"],
+ ]
+
+ expected_state_ids = ["JR"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_offtopic_pl(self):
+ events = [
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ }
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PC",
+ sender=CHARLIE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ CHARLIE: 0,
+ },
+ },
+ ),
+ ]
+
+ edges = [
+ ["END", "PC", "PB", "PA", "START"],
+ ["END", "PA"],
+ ]
+
+ expected_state_ids = ["PC"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_basic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["PA2", "T2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic_reset(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MB",
+ sender=ALICE,
+ type=EventTypes.Member,
+ state_key=BOB,
+ content={"membership": Membership.BAN},
+ ),
+ ]
+
+ edges = [
+ ["END", "MB", "T2", "PA", "T1", "START"],
+ ["END", "T1"],
+ ]
+
+ expected_state_ids = ["T1", "MB", "PA"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def test_topic(self):
+ events = [
+ FakeEvent(
+ id="T1",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T2",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 0,
+ },
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key='',
+ content={
+ "users": {
+ ALICE: 100,
+ BOB: 50,
+ },
+ },
+ ),
+ FakeEvent(
+ id="T3",
+ sender=BOB,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ FakeEvent(
+ id="MZ1",
+ sender=ZARA,
+ type=EventTypes.Message,
+ state_key=None,
+ content={},
+ ),
+ FakeEvent(
+ id="T4",
+ sender=ALICE,
+ type=EventTypes.Topic,
+ state_key="",
+ content={},
+ ),
+ ]
+
+ edges = [
+ ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "MZ1", "T3", "PB", "PA1"],
+ ]
+
+ expected_state_ids = ["T4", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
+ def do_check(self, events, edges, expected_state_ids):
+ """Take a list of events and edges and calculate the state of the
+ graph at END, and asserts it matches `expected_state_ids`
+
+ Args:
+ events (list[FakeEvent])
+ edges (list[list[str]]): A list of chains of event edges, e.g.
+ `[[A, B, C]]` are edges A->B and B->C.
+ expected_state_ids (list[str]): The expected state at END, (excluding
+ the keys that haven't changed since START).
+ """
+ # We want to sort the events into topological order for processing.
+ graph = {}
+
+ # node_id -> FakeEvent
+ fake_event_map = {}
+
+ for ev in itertools.chain(INITIAL_EVENTS, events):
+ graph[ev.node_id] = set()
+ fake_event_map[ev.node_id] = ev
+
+ for a, b in pairwise(INITIAL_EDGES):
+ graph[a].add(b)
+
+ for edge_list in edges:
+ for a, b in pairwise(edge_list):
+ graph[a].add(b)
+
+ # event_id -> FrozenEvent
+ event_map = {}
+ # node_id -> state
+ state_at_event = {}
+
+ # We copy the map as the sort consumes the graph
+ graph_copy = {k: set(v) for k, v in graph.items()}
+
+ for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
+ fake_event = fake_event_map[node_id]
+ event_id = fake_event.event_id
+
+ prev_events = list(graph[node_id])
+
+ if len(prev_events) == 0:
+ state_before = {}
+ elif len(prev_events) == 1:
+ state_before = dict(state_at_event[prev_events[0]])
+ else:
+ state_d = resolve_events_with_store(
+ [state_at_event[n] for n in prev_events],
+ event_map=event_map,
+ state_res_store=TestStateResolutionStore(event_map),
+ )
+
+ self.assertTrue(state_d.called)
+ state_before = state_d.result
+
+ state_after = dict(state_before)
+ if fake_event.state_key is not None:
+ state_after[(fake_event.type, fake_event.state_key)] = event_id
+
+ auth_types = set(auth_types_for_event(fake_event))
+
+ auth_events = []
+ for key in auth_types:
+ if key in state_before:
+ auth_events.append(state_before[key])
+
+ event = fake_event.to_event(auth_events, prev_events)
+
+ state_at_event[node_id] = state_after
+ event_map[event_id] = event
+
+ expected_state = {}
+ for node_id in expected_state_ids:
+ # expected_state_ids are node IDs rather than event IDs,
+ # so we have to convert
+ event_id = EventID(node_id, "example.com").to_string()
+ event = event_map[event_id]
+
+ key = (event.type, event.state_key)
+
+ expected_state[key] = event_id
+
+ start_state = state_at_event["START"]
+ end_state = {
+ key: value
+ for key, value in state_at_event["END"].items()
+ if key in expected_state or start_state.get(key) != value
+ }
+
+ self.assertEqual(expected_state, end_state)
+
+
+class LexicographicalTestCase(unittest.TestCase):
+ def test_simple(self):
+ graph = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
+
+ res = list(lexicographical_topological_sort(graph, key=lambda x: x))
+
+ self.assertEqual(["o", "l", "n", "m", "p"], res)
+
+
+def pairwise(iterable):
+ "s -> (s0,s1), (s1,s2), (s2, s3), ..."
+ a, b = itertools.tee(iterable)
+ next(b, None)
+ return zip(a, b)
+
+
+@attr.s
+class TestStateResolutionStore(object):
+ event_map = attr.ib()
+
+ def get_events(self, event_ids, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ """
+
+ return {
+ eid: self.event_map[eid]
+ for eid in event_ids
+ if eid in self.event_map
+ }
+
+ def get_auth_chain(self, event_ids):
+ """Gets the full auth chain for a set of events (including rejected
+ events).
+
+ Includes the given event IDs in the result.
+
+ Note that:
+ 1. All events must be state events.
+ 2. For v1 rooms this may not have the full auth chain in the
+ presence of rejected events
+
+ Args:
+ event_ids (list): The event IDs of the events to fetch the auth
+ chain for. Must be state events.
+
+ Returns:
+ Deferred[list[str]]: List of event IDs of the auth chain.
+ """
+
+ # Simple DFS for auth chain
+ result = set()
+ stack = list(event_ids)
+ while stack:
+ event_id = stack.pop()
+ if event_id in result:
+ continue
+
+ result.add(event_id)
+
+ event = self.event_map[event_id]
+ for aid, _ in event.auth_events:
+ stack.append(aid)
+
+ return list(result)
diff --git a/tests/utils.py b/tests/utils.py
index dd347a0c59..565bb60d08 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -124,6 +124,7 @@ def default_config(name):
config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None
config.media_storage_providers = []
+ config.autocreate_auto_join_rooms = True
config.auto_join_rooms = []
config.limit_usage_by_mau = False
config.hs_disabled = False
|