summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2019-11-26 10:53:48 +0000
committerGitHub <noreply@github.com>2019-11-26 10:53:48 +0000
commit4c1b799e1b9f8e04f6cd86a0977e28685ac20aea (patch)
treefcaea2671bf38721520191c13bb37adf18ef84c5 /tests
parentFormat changelog (diff)
parentMake sure that we close cursors before returning from a query (#6408) (diff)
downloadsynapse-4c1b799e1b9f8e04f6cd86a0977e28685ac20aea.tar.xz
Merge branch 'develop' into babolivier/context_filters
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_federation.py126
-rw-r--r--tests/rest/admin/test_admin.py6
-rw-r--r--tests/rest/client/v1/test_rooms.py173
-rw-r--r--tests/rest/media/v1/test_url_preview.py35
-rw-r--r--tests/server.py2
-rw-r--r--tests/storage/test_purge.py15
-rw-r--r--tests/test_state.py61
7 files changed, 399 insertions, 19 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d56220f403..b4d92cf732 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, Codes
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 
 from tests import unittest
 
+logger = logging.getLogger(__name__)
+
 
 class FederationTestCase(unittest.HomeserverTestCase):
     servlets = [
@@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.code, 403, failure)
         self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
         self.assertEqual(failure.msg, "You are not invited to this room.")
+
+    def test_rejected_message_event_state(self):
+        """
+        Check that we store the state group correctly for rejected non-state events.
+
+        Regression test for #6289.
+        """
+        OTHER_SERVER = "otherserver"
+        OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # pretend that another server has joined
+        join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+        # check the state group
+        sg = self.successResultOf(
+            self.store._get_state_group_for_event(join_event.event_id)
+        )
+
+        # build and send an event which will be rejected
+        ev = event_from_pdu_json(
+            {
+                "type": EventTypes.Message,
+                "content": {},
+                "room_id": room_id,
+                "sender": "@yetanotheruser:" + OTHER_SERVER,
+                "depth": join_event["depth"] + 1,
+                "prev_events": [join_event.event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            join_event.format_version,
+        )
+
+        with LoggingContext(request="send_rejected"):
+            d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+        self.get_success(d)
+
+        # that should have been rejected
+        e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+        self.assertIsNotNone(e.rejected_reason)
+
+        # ... and the state group should be the same as before
+        sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+        self.assertEqual(sg, sg2)
+
+    def test_rejected_state_event_state(self):
+        """
+        Check that we store the state group correctly for rejected state events.
+
+        Regression test for #6289.
+        """
+        OTHER_SERVER = "otherserver"
+        OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # pretend that another server has joined
+        join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+        # check the state group
+        sg = self.successResultOf(
+            self.store._get_state_group_for_event(join_event.event_id)
+        )
+
+        # build and send an event which will be rejected
+        ev = event_from_pdu_json(
+            {
+                "type": "org.matrix.test",
+                "state_key": "test_key",
+                "content": {},
+                "room_id": room_id,
+                "sender": "@yetanotheruser:" + OTHER_SERVER,
+                "depth": join_event["depth"] + 1,
+                "prev_events": [join_event.event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            join_event.format_version,
+        )
+
+        with LoggingContext(request="send_rejected"):
+            d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+        self.get_success(d)
+
+        # that should have been rejected
+        e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+        self.assertIsNotNone(e.rejected_reason)
+
+        # ... and the state group should be the same as before
+        sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+        self.assertEqual(sg, sg2)
+
+    def _build_and_send_join_event(self, other_server, other_user, room_id):
+        join_event = self.get_success(
+            self.handler.on_make_join_request(other_server, room_id, other_user)
+        )
+        # the auth code requires that a signature exists, but doesn't check that
+        # signature... go figure.
+        join_event.signatures[other_server] = {"x": "y"}
+        with LoggingContext(request="send_join"):
+            d = run_in_background(
+                self.handler.on_send_join_request, other_server, join_event
+            )
+        self.get_success(d)
+
+        # sanity-check: the room should show that the new user is a member
+        r = self.get_success(self.store.get_current_state_ids(room_id))
+        self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+        return join_event
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 8e1ca8b738..9575058252 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -628,10 +628,12 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
             "local_invites",
             "room_account_data",
             "room_tags",
+            "state_groups",
+            "state_groups_state",
         ):
             count = self.get_success(
                 self.store._simple_select_one_onecol(
-                    table="events",
+                    table=table,
                     keyvalues={"room_id": room_id},
                     retcol="COUNT(*)",
                     desc="test_purge_room",
@@ -639,3 +641,5 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
             )
 
             self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+    test_purge_room.skip = "Disabled because it's currently broken"
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index c5d67fc1cd..c0b046a50c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -27,7 +27,9 @@ from twisted.internet import defer
 
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.handlers.pagination import PurgeStatus
 from synapse.rest.client.v1 import login, profile, room
+from synapse.util.stringutils import random_string
 
 from tests import unittest
 
@@ -813,6 +815,177 @@ class RoomMessageListTestCase(RoomBase):
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
+    def test_filter_labels(self):
+        """Test that we can filter by a label."""
+        message_filter = json.dumps(
+            {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
+        )
+
+        events = self._test_filter_labels(message_filter)
+
+        self.assertEqual(len(events), 2, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+    def test_filter_not_labels(self):
+        """Test that we can filter by the absence of a label."""
+        message_filter = json.dumps(
+            {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
+        )
+
+        events = self._test_filter_labels(message_filter)
+
+        self.assertEqual(len(events), 3, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+        self.assertEqual(
+            events[2]["content"]["body"], "with two wrong labels", events[2]
+        )
+
+    def test_filter_labels_not_labels(self):
+        """Test that we can filter by both a label and the absence of another label."""
+        sync_filter = json.dumps(
+            {
+                "types": [EventTypes.Message],
+                "org.matrix.labels": ["#work"],
+                "org.matrix.not_labels": ["#notfun"],
+            }
+        )
+
+        events = self._test_filter_labels(sync_filter)
+
+        self.assertEqual(len(events), 1, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+    def _test_filter_labels(self, message_filter):
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with right label",
+                EventContentFields.LABELS: ["#fun"],
+            },
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "without label"},
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with wrong label",
+                EventContentFields.LABELS: ["#work"],
+            },
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with two wrong labels",
+                EventContentFields.LABELS: ["#work", "#notfun"],
+            },
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with right label",
+                EventContentFields.LABELS: ["#fun"],
+            },
+        )
+
+        token = "s0_0_0_0_0_0_0_0_0"
+        request, channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?access_token=x&from=%s&filter=%s"
+            % (self.room_id, token, message_filter),
+        )
+        self.render(request)
+
+        return channel.json_body["chunk"]
+
+    def test_room_messages_purge(self):
+        store = self.hs.get_datastore()
+        pagination_handler = self.hs.get_pagination_handler()
+
+        # Send a first message in the room, which will be removed by the purge.
+        first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
+        first_token = self.get_success(
+            store.get_topological_token_for_event(first_event_id)
+        )
+
+        # Send a second message in the room, which won't be removed, and which we'll
+        # use as the marker to purge events before.
+        second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
+        second_token = self.get_success(
+            store.get_topological_token_for_event(second_event_id)
+        )
+
+        # Send a third event in the room to ensure we don't fall under any edge case
+        # due to our marker being the latest forward extremity in the room.
+        self.helper.send(self.room_id, "message 3")
+
+        # Check that we get the first and second message when querying /messages.
+        request, channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+            % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+        # Purge every event before the second event.
+        purge_id = random_string(16)
+        pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+        self.get_success(
+            pagination_handler._purge_history(
+                purge_id=purge_id,
+                room_id=self.room_id,
+                token=second_token,
+                delete_local_events=True,
+            )
+        )
+
+        # Check that we only get the second message through /message now that the first
+        # has been purged.
+        request, channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+            % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
+
+        # Check that we get no event, but also no error, when querying /messages with
+        # the token that was pointing at the first event, because we don't have it
+        # anymore.
+        request, channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+            % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
+
 
 class RoomSearchTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 976652aee8..852b8ab11c 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -247,6 +247,41 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
+    def test_overlong_title(self):
+        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+
+        end_content = (
+            b"<html><head>"
+            b"<title>" + b"x" * 2000 + b"</title>"
+            b'<meta property="og:description" content="hi" />'
+            b"</head></html>"
+        )
+
+        request, channel = self.make_request(
+            "GET", "url_preview?url=http://matrix.org", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
+            )
+            % (len(end_content),)
+            + end_content
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        res = channel.json_body
+        # We should only see the `og:description` field, as `title` is too long and should be stripped out
+        self.assertCountEqual(["og:description"], res.keys())
+
     def test_ipaddr(self):
         """
         IP addresses can be previewed directly.
diff --git a/tests/server.py b/tests/server.py
index f878aeaada..2b7cf4242e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -379,6 +379,7 @@ class FakeTransport(object):
 
     disconnecting = False
     disconnected = False
+    connected = True
     buffer = attr.ib(default=b"")
     producer = attr.ib(default=None)
     autoflush = attr.ib(default=True)
@@ -402,6 +403,7 @@ class FakeTransport(object):
                     "FakeTransport: Delaying disconnect until buffer is flushed"
                 )
             else:
+                self.connected = False
                 self.disconnected = True
 
     def abortConnection(self):
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index f671599cb8..b9fafaa1a6 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase):
         third = self.helper.send(self.room_id, body="test3")
         last = self.helper.send(self.room_id, body="test4")
 
-        storage = self.hs.get_datastore()
+        store = self.hs.get_datastore()
+        storage = self.hs.get_storage()
 
         # Get the topological token
-        event = storage.get_topological_token_for_event(last["event_id"])
+        event = store.get_topological_token_for_event(last["event_id"])
         self.pump()
         event = self.successResultOf(event)
 
         # Purge everything before this topological token
-        purge = storage.purge_history(self.room_id, event, True)
+        purge = storage.purge_events.purge_history(self.room_id, event, True)
         self.pump()
         self.assertEqual(self.successResultOf(purge), None)
 
         # Try and get the events
-        get_first = storage.get_event(first["event_id"])
-        get_second = storage.get_event(second["event_id"])
-        get_third = storage.get_event(third["event_id"])
-        get_last = storage.get_event(last["event_id"])
+        get_first = store.get_event(first["event_id"])
+        get_second = store.get_event(second["event_id"])
+        get_third = store.get_event(third["event_id"])
+        get_last = store.get_event(last["event_id"])
         self.pump()
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
diff --git a/tests/test_state.py b/tests/test_state.py
index 38246555bd..176535947a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,6 +21,7 @@ from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
 from synapse.state import StateHandler, StateResolutionHandler
 
 from tests import unittest
@@ -198,16 +199,22 @@ class StateTestCase(unittest.TestCase):
 
         self.store.register_events(graph.walk())
 
-        context_store = {}
+        context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        ctx_c = context_store["C"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
         self.assertEqual(2, len(prev_state_ids))
 
+        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
         graph = Graph(
@@ -241,12 +248,19 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        # C ends up winning the resolution between B and C
+
+        ctx_c = context_store["C"]
+        ctx_d = context_store["D"]
 
+        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
         self.assertSetEqual(
             {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
         )
 
+        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     @defer.inlineCallbacks
     def test_branch_have_banned_conflict(self):
         graph = Graph(
@@ -292,11 +306,18 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+        # C ends up winning the resolution between C and D because bans win over other
+        # changes
+
+        ctx_c = context_store["C"]
+        ctx_e = context_store["E"]
 
+        prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
         self.assertSetEqual(
             {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
         )
+        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
 
     @defer.inlineCallbacks
     def test_branch_have_perms_conflict(self):
@@ -360,12 +381,20 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        # B ends up winning the resolution between B and C because power levels
+        # win over other changes.
 
+        ctx_b = context_store["B"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
         )
 
+        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     def _add_depths(self, nodes, edges):
         def _get_depth(ev):
             node = nodes[ev]
@@ -390,13 +419,16 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        self.assertEqual(
-            set(e.event_id for e in old_state), set(current_state_ids.values())
+        current_state_ids = yield context.get_current_state_ids(self.store)
+        self.assertCountEqual(
+            (e.event_id for e in old_state), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group)
+        self.assertIsNotNone(context.state_group_before_event)
+        self.assertEqual(context.state_group_before_event, context.state_group)
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
@@ -411,11 +443,18 @@ class StateTestCase(unittest.TestCase):
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
         prev_state_ids = yield context.get_prev_state_ids(self.store)
+        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        self.assertEqual(
-            set(e.event_id for e in old_state), set(prev_state_ids.values())
+        current_state_ids = yield context.get_current_state_ids(self.store)
+        self.assertCountEqual(
+            (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
+        self.assertIsNotNone(context.state_group_before_event)
+        self.assertNotEqual(context.state_group_before_event, context.state_group)
+        self.assertEqual(context.state_group_before_event, context.prev_group)
+        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         prev_event_id = "prev_event_id"