summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorWill Hunt <willh@matrix.org>2021-05-11 22:28:35 +0100
committerWill Hunt <willh@matrix.org>2021-05-11 22:28:35 +0100
commitdacc395dcae59824b0b30094fbc9a08a36df8113 (patch)
tree9225e85f329edc3bdb74332dcb837d6c888ff536 /tests
parentfix version (diff)
parentTests for to-device messages (#9965) (diff)
downloadsynapse-dacc395dcae59824b0b30094fbc9a08a36df8113.tar.xz
Merge remote-tracking branch 'origin/develop' into hs/hacked-together-event-cache
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_server.py19
-rw-r--r--tests/handlers/test_presence.py14
-rw-r--r--tests/handlers/test_space_summary.py81
-rw-r--r--tests/push/test_push_rule_evaluator.py166
-rw-r--r--tests/rest/client/v2_alpha/test_sendtodevice.py201
-rw-r--r--tests/storage/test_cleanup_extrems.py4
-rw-r--r--tests/util/test_glob_to_regex.py59
7 files changed, 531 insertions, 13 deletions
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 8508b6bd3b..1737891564 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -74,6 +74,25 @@ class ServerACLsTestCase(unittest.TestCase):
         self.assertFalse(server_matches_acl_event("[1:2::]", e))
         self.assertTrue(server_matches_acl_event("1:2:3:4", e))
 
+    def test_wildcard_matching(self):
+        e = _create_acl_event({"allow": ["good*.com"]})
+        self.assertTrue(
+            server_matches_acl_event("good.com", e),
+            "* matches 0 characters",
+        )
+        self.assertTrue(
+            server_matches_acl_event("GOOD.COM", e),
+            "pattern is case-insensitive",
+        )
+        self.assertTrue(
+            server_matches_acl_event("good.aa.com", e),
+            "* matches several characters, including '.'",
+        )
+        self.assertFalse(
+            server_matches_acl_event("ishgood.com", e),
+            "pattern does not allow prefixes",
+        )
+
 
 class StateQueryTests(unittest.FederatingHomeserverTestCase):
 
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index ce330e79cc..1ffab709fc 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -729,7 +729,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server2"], states={expected_state}
+            destinations={"server2"}, states=[expected_state]
         )
 
         #
@@ -740,7 +740,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         self._add_new_user(room_id, "@bob:server3")
 
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server3"], states={expected_state}
+            destinations={"server3"}, states=[expected_state]
         )
 
     def test_remote_gets_presence_when_local_user_joins(self):
@@ -788,14 +788,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             self.presence_handler.current_state_for_user("@test2:server")
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
-        self.assertEqual(
-            self.federation_sender.send_presence_to_destinations.call_count, 2
-        )
-        self.federation_sender.send_presence_to_destinations.assert_any_call(
-            destinations=["server3"], states={expected_state}
-        )
-        self.federation_sender.send_presence_to_destinations.assert_any_call(
-            destinations=["server2"], states={expected_state}
+        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
+            destinations={"server2", "server3"}, states=[expected_state]
         )
 
     def _add_new_user(self, room_id, user_id):
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
new file mode 100644
index 0000000000..2c5e81531b
--- /dev/null
+++ b/tests/handlers/test_space_summary.py
@@ -0,0 +1,81 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from typing import Any, Optional
+from unittest import mock
+
+from synapse.handlers.space_summary import _child_events_comparison_key
+
+from tests import unittest
+
+
+def _create_event(room_id: str, order: Optional[Any] = None):
+    result = mock.Mock()
+    result.room_id = room_id
+    result.content = {}
+    if order is not None:
+        result.content["order"] = order
+    return result
+
+
+def _order(*events):
+    return sorted(events, key=_child_events_comparison_key)
+
+
+class TestSpaceSummarySort(unittest.TestCase):
+    def test_no_order_last(self):
+        """An event with no ordering is placed behind those with an ordering."""
+        ev1 = _create_event("!abc:test")
+        ev2 = _create_event("!xyz:test", "xyz")
+
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+    def test_order(self):
+        """The ordering should be used."""
+        ev1 = _create_event("!abc:test", "xyz")
+        ev2 = _create_event("!xyz:test", "abc")
+
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+    def test_order_room_id(self):
+        """Room ID is a tie-breaker for ordering."""
+        ev1 = _create_event("!abc:test", "abc")
+        ev2 = _create_event("!xyz:test", "abc")
+
+        self.assertEqual([ev1, ev2], _order(ev1, ev2))
+
+    def test_invalid_ordering_type(self):
+        """Invalid orderings are considered the same as missing."""
+        ev1 = _create_event("!abc:test", 1)
+        ev2 = _create_event("!xyz:test", "xyz")
+
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+        ev1 = _create_event("!abc:test", {})
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+        ev1 = _create_event("!abc:test", [])
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+        ev1 = _create_event("!abc:test", True)
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+    def test_invalid_ordering_value(self):
+        """Invalid orderings are considered the same as missing."""
+        ev1 = _create_event("!abc:test", "foo\n")
+        ev2 = _create_event("!xyz:test", "xyz")
+
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+        ev1 = _create_event("!abc:test", "a" * 51)
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 45906ce720..a52e89e407 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict
+
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
 from synapse.push import push_rule_evaluator
@@ -66,6 +68,170 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         # A display name with spaces should work fine.
         self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
 
+    def _assert_matches(
+        self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+    ) -> None:
+        evaluator = self._get_evaluator(content)
+        self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
+
+    def _assert_not_matches(
+        self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+    ) -> None:
+        evaluator = self._get_evaluator(content)
+        self.assertFalse(
+            evaluator.matches(condition, "@user:test", "display_name"), msg
+        )
+
+    def test_event_match_body(self):
+        """Check that event_match conditions on content.body work as expected"""
+
+        # if the key is `content.body`, the pattern matches substrings.
+
+        # non-wildcards should match
+        condition = {
+            "kind": "event_match",
+            "key": "content.body",
+            "pattern": "foobaz",
+        }
+        self._assert_matches(
+            condition,
+            {"body": "aaa FoobaZ zzz"},
+            "patterns should match and be case-insensitive",
+        )
+        self._assert_not_matches(
+            condition,
+            {"body": "aa xFoobaZ yy"},
+            "pattern should only match at word boundaries",
+        )
+        self._assert_not_matches(
+            condition,
+            {"body": "aa foobazx yy"},
+            "pattern should only match at word boundaries",
+        )
+
+        # wildcards should match
+        condition = {
+            "kind": "event_match",
+            "key": "content.body",
+            "pattern": "f?o*baz",
+        }
+
+        self._assert_matches(
+            condition,
+            {"body": "aaa FoobarbaZ zzz"},
+            "* should match string and pattern should be case-insensitive",
+        )
+        self._assert_matches(
+            condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
+        )
+        self._assert_not_matches(
+            condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
+        )
+        self._assert_not_matches(
+            condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
+        )
+        self._assert_not_matches(
+            condition,
+            {"body": "aa xfooxbaz yy"},
+            "pattern should only match at word boundaries",
+        )
+        self._assert_not_matches(
+            condition,
+            {"body": "aa fooxbazx yy"},
+            "pattern should only match at word boundaries",
+        )
+
+        # test backslashes
+        condition = {
+            "kind": "event_match",
+            "key": "content.body",
+            "pattern": r"f\oobaz",
+        }
+        self._assert_matches(
+            condition,
+            {"body": r"F\oobaz"},
+            "backslash should match itself",
+        )
+        condition = {
+            "kind": "event_match",
+            "key": "content.body",
+            "pattern": r"f\?obaz",
+        }
+        self._assert_matches(
+            condition,
+            {"body": r"F\oobaz"},
+            r"? after \ should match any character",
+        )
+
+    def test_event_match_non_body(self):
+        """Check that event_match conditions on other keys work as expected"""
+
+        # if the key is anything other than 'content.body', the pattern must match the
+        # whole value.
+
+        # non-wildcards should match
+        condition = {
+            "kind": "event_match",
+            "key": "content.value",
+            "pattern": "foobaz",
+        }
+        self._assert_matches(
+            condition,
+            {"value": "FoobaZ"},
+            "patterns should match and be case-insensitive",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "xFoobaZ"},
+            "pattern should only match at the start/end of the value",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "FoobaZz"},
+            "pattern should only match at the start/end of the value",
+        )
+
+        # wildcards should match
+        condition = {
+            "kind": "event_match",
+            "key": "content.value",
+            "pattern": "f?o*baz",
+        }
+        self._assert_matches(
+            condition,
+            {"value": "FoobarbaZ"},
+            "* should match string and pattern should be case-insensitive",
+        )
+        self._assert_matches(
+            condition, {"value": "foobaz"}, "* should match 0 characters"
+        )
+        self._assert_not_matches(
+            condition, {"value": "fobbaz"}, "? should not match 0 characters"
+        )
+        self._assert_not_matches(
+            condition, {"value": "fiiobaz"}, "? should not match 2 characters"
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "xfooxbaz"},
+            "pattern should only match at the start/end of the value",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "fooxbazx"},
+            "pattern should only match at the start/end of the value",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "x\nfooxbaz"},
+            "pattern should not match after a newline",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "fooxbaz\nx"},
+            "pattern should not match before a newline",
+        )
+
     def test_no_body(self):
         """Not having a body shouldn't break the evaluator."""
         evaluator = self._get_evaluator({})
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py
new file mode 100644
index 0000000000..c9c99cc5d7
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_sendtodevice.py
@@ -0,0 +1,201 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import sendtodevice, sync
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+class SendToDeviceTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        sendtodevice.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def test_user_to_user(self):
+        """A to-device message from one user to another should get delivered"""
+
+        user1 = self.register_user("u1", "pass")
+        user1_tok = self.login("u1", "pass", "d1")
+
+        user2 = self.register_user("u2", "pass")
+        user2_tok = self.login("u2", "pass", "d2")
+
+        # send the message
+        test_msg = {"foo": "bar"}
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.test/1234",
+            content={"messages": {user2: {"d2": test_msg}}},
+            access_token=user1_tok,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        # check it appears
+        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        self.assertEqual(channel.code, 200, channel.result)
+        expected_result = {
+            "events": [
+                {
+                    "sender": user1,
+                    "type": "m.test",
+                    "content": test_msg,
+                }
+            ]
+        }
+        self.assertEqual(channel.json_body["to_device"], expected_result)
+
+        # it should re-appear if we do another sync
+        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body["to_device"], expected_result)
+
+        # it should *not* appear if we do an incremental sync
+        sync_token = channel.json_body["next_batch"]
+        channel = self.make_request(
+            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
+
+    @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
+    def test_local_room_key_request(self):
+        """m.room_key_request has special-casing; test from local user"""
+        user1 = self.register_user("u1", "pass")
+        user1_tok = self.login("u1", "pass", "d1")
+
+        user2 = self.register_user("u2", "pass")
+        user2_tok = self.login("u2", "pass", "d2")
+
+        # send three messages
+        for i in range(3):
+            chan = self.make_request(
+                "PUT",
+                f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}",
+                content={"messages": {user2: {"d2": {"idx": i}}}},
+                access_token=user1_tok,
+            )
+            self.assertEqual(chan.code, 200, chan.result)
+
+        # now sync: we should get two of the three
+        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        self.assertEqual(channel.code, 200, channel.result)
+        msgs = channel.json_body["to_device"]["events"]
+        self.assertEqual(len(msgs), 2)
+        for i in range(2):
+            self.assertEqual(
+                msgs[i],
+                {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+            )
+        sync_token = channel.json_body["next_batch"]
+
+        # ... time passes
+        self.reactor.advance(1)
+
+        # and we can send more messages
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.room_key_request/3",
+            content={"messages": {user2: {"d2": {"idx": 3}}}},
+            access_token=user1_tok,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        # ... which should arrive
+        channel = self.make_request(
+            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        msgs = channel.json_body["to_device"]["events"]
+        self.assertEqual(len(msgs), 1)
+        self.assertEqual(
+            msgs[0],
+            {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}},
+        )
+
+    @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
+    def test_remote_room_key_request(self):
+        """m.room_key_request has special-casing; test from remote user"""
+        user2 = self.register_user("u2", "pass")
+        user2_tok = self.login("u2", "pass", "d2")
+
+        federation_registry = self.hs.get_federation_registry()
+
+        # send three messages
+        for i in range(3):
+            self.get_success(
+                federation_registry.on_edu(
+                    "m.direct_to_device",
+                    "remote_server",
+                    {
+                        "sender": "@user:remote_server",
+                        "type": "m.room_key_request",
+                        "messages": {user2: {"d2": {"idx": i}}},
+                        "message_id": f"{i}",
+                    },
+                )
+            )
+
+        # now sync: we should get two of the three
+        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        self.assertEqual(channel.code, 200, channel.result)
+        msgs = channel.json_body["to_device"]["events"]
+        self.assertEqual(len(msgs), 2)
+        for i in range(2):
+            self.assertEqual(
+                msgs[i],
+                {
+                    "sender": "@user:remote_server",
+                    "type": "m.room_key_request",
+                    "content": {"idx": i},
+                },
+            )
+        sync_token = channel.json_body["next_batch"]
+
+        # ... time passes
+        self.reactor.advance(1)
+
+        # and we can send more messages
+        self.get_success(
+            federation_registry.on_edu(
+                "m.direct_to_device",
+                "remote_server",
+                {
+                    "sender": "@user:remote_server",
+                    "type": "m.room_key_request",
+                    "messages": {user2: {"d2": {"idx": 3}}},
+                    "message_id": "3",
+                },
+            )
+        )
+
+        # ... which should arrive
+        channel = self.make_request(
+            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        msgs = channel.json_body["to_device"]["events"]
+        self.assertEqual(len(msgs), 1)
+        self.assertEqual(
+            msgs[0],
+            {
+                "sender": "@user:remote_server",
+                "type": "m.room_key_request",
+                "content": {"idx": 3},
+            },
+        )
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index aa20588bbe..77c4fe721c 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -47,10 +47,8 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         )
 
         schema_path = os.path.join(
-            prepare_database.dir_path,
-            "databases",
+            prepare_database.schema_path,
             "main",
-            "schema",
             "delta",
             "54",
             "delete_forward_extremities.sql",
diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py
new file mode 100644
index 0000000000..220accb92b
--- /dev/null
+++ b/tests/util/test_glob_to_regex.py
@@ -0,0 +1,59 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util import glob_to_regex
+
+from tests.unittest import TestCase
+
+
+class GlobToRegexTestCase(TestCase):
+    def test_literal_match(self):
+        """patterns without wildcards should match"""
+        pat = glob_to_regex("foobaz")
+        self.assertTrue(
+            pat.match("FoobaZ"), "patterns should match and be case-insensitive"
+        )
+        self.assertFalse(
+            pat.match("x foobaz"), "pattern should not match at word boundaries"
+        )
+
+    def test_wildcard_match(self):
+        pat = glob_to_regex("f?o*baz")
+
+        self.assertTrue(
+            pat.match("FoobarbaZ"),
+            "* should match string and pattern should be case-insensitive",
+        )
+        self.assertTrue(pat.match("foobaz"), "* should match 0 characters")
+        self.assertFalse(pat.match("fooxaz"), "the character after * must match")
+        self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters")
+        self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters")
+
+    def test_multi_wildcard(self):
+        """patterns with multiple wildcards in a row should match"""
+        pat = glob_to_regex("**baz")
+        self.assertTrue(pat.match("agsgsbaz"), "** should match any string")
+        self.assertTrue(pat.match("baz"), "** should match the empty string")
+        self.assertEqual(pat.pattern, r"\A.{0,}baz\Z")
+
+        pat = glob_to_regex("*?baz")
+        self.assertTrue(pat.match("agsgsbaz"), "*? should match any string")
+        self.assertTrue(pat.match("abaz"), "*? should match a single char")
+        self.assertFalse(pat.match("baz"), "*? should not match the empty string")
+        self.assertEqual(pat.pattern, r"\A.{1,}baz\Z")
+
+        pat = glob_to_regex("a?*?*?baz")
+        self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars")
+        self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars")
+        self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars")
+        self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z")