summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_sender.py6
-rw-r--r--tests/handlers/test_e2e_keys.py18
-rw-r--r--tests/handlers/test_e2e_room_keys.py1
-rw-r--r--tests/handlers/test_user_directory.py91
-rw-r--r--tests/push/test_push_rule_evaluator.py39
-rw-r--r--tests/replication/tcp/streams/test_events.py74
-rw-r--r--tests/replication/tcp/streams/test_typing.py88
-rw-r--r--tests/test_federation.py2
8 files changed, 282 insertions, 37 deletions
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ff12539041..1a9bd5f37d 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -21,6 +21,7 @@ from signedjson.types import BaseKey, SigningKey
 
 from twisted.internet import defer
 
+from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.rest import admin
 from synapse.rest.client.v1 import login
 from synapse.types import JsonDict, ReadReceipt
@@ -536,7 +537,10 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
     return {
         "user_id": user_id,
         "device_id": device_id,
-        "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+        "algorithms": [
+            "m.olm.curve25519-aes-sha2",
+            RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+        ],
         "keys": {
             "curve25519:" + device_id: "curve25519+key",
             key_id(sk): encode_pubkey(sk),
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e1e144b2e7..6c1dc72bd1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 import synapse.handlers.e2e_keys
 import synapse.storage
 from synapse.api import errors
+from synapse.api.constants import RoomEncryptionAlgorithms
 
 from tests import unittest, utils
 
@@ -222,7 +223,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_key_1 = {
             "user_id": local_user,
             "device_id": "abc",
-            "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
             "keys": {
                 "ed25519:abc": "base64+ed25519+key",
                 "curve25519:abc": "base64+curve25519+key",
@@ -232,7 +236,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_key_2 = {
             "user_id": local_user,
             "device_id": "def",
-            "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
             "keys": {
                 "ed25519:def": "base64+ed25519+key",
                 "curve25519:def": "base64+curve25519+key",
@@ -315,7 +322,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_key = {
             "user_id": local_user,
             "device_id": device_id,
-            "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
             "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
             "signatures": {local_user: {"ed25519:xyz": "something"}},
         }
@@ -392,7 +402,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                         "device_id": device_id,
                         "algorithms": [
                             "m.olm.curve25519-aes-sha2",
-                            "m.megolm.v1.aes-sha2",
+                            RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
                         ],
                         "keys": {
                             "curve25519:xyz": "curve25519+key",
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 70f172eb02..822ea42dde 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -96,6 +96,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # check we can retrieve it as the current version
         res = yield self.handler.get_version_info(self.local_user)
         version_etag = res["etag"]
+        self.assertIsInstance(version_etag, str)
         del res["etag"]
         self.assertDictEqual(
             res,
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c15bce5bef..23fcc372dd 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -17,12 +17,13 @@ from mock import Mock
 from twisted.internet import defer
 
 import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import user_directory
 from synapse.storage.roommember import ProfileInfo
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class UserDirectoryTestCase(unittest.HomeserverTestCase):
@@ -147,6 +148,94 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user3", 10))
         self.assertEqual(len(s["results"]), 0)
 
+    @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+    def test_encrypted_by_default_config_option_all(self):
+        """Tests that invite-only and non-invite-only rooms have encryption enabled by
+        default when the config option encryption_enabled_by_default_for_room_type is "all".
+        """
+        # Create a user
+        user = self.register_user("user", "pass")
+        user_token = self.login(user, "pass")
+
+        # Create an invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+        # Check that the room has an encryption state event
+        event_content = self.helper.get_state(
+            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+        )
+        self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+        # Create a non invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+        # Check that the room has an encryption state event
+        event_content = self.helper.get_state(
+            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+        )
+        self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+    @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+    def test_encrypted_by_default_config_option_invite(self):
+        """Tests that only new, invite-only rooms have encryption enabled by default when
+        the config option encryption_enabled_by_default_for_room_type is "invite".
+        """
+        # Create a user
+        user = self.register_user("user", "pass")
+        user_token = self.login(user, "pass")
+
+        # Create an invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+        # Check that the room has an encryption state event
+        event_content = self.helper.get_state(
+            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+        )
+        self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+        # Create a non invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+        # Check that the room does not have an encryption state event
+        self.helper.get_state(
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
+            expect_code=404,
+        )
+
+    @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+    def test_encrypted_by_default_config_option_off(self):
+        """Tests that neither new invite-only nor non-invite-only rooms have encryption
+        enabled by default when the config option
+        encryption_enabled_by_default_for_room_type is "off".
+        """
+        # Create a user
+        user = self.register_user("user", "pass")
+        user_token = self.login(user, "pass")
+
+        # Create an invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+        # Check that the room does not have an encryption state event
+        self.helper.get_state(
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
+            expect_code=404,
+        )
+
+        # Create a non invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+        # Check that the room does not have an encryption state event
+        self.helper.get_state(
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
+            expect_code=404,
+        )
+
     def test_spam_checker(self):
         """
         A user which fails to the spam checks will not appear in search results.
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 9ae6a87d7b..af35d23aea 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -21,7 +21,7 @@ from tests import unittest
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def setUp(self):
+    def _get_evaluator(self, content):
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -29,37 +29,58 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                 "sender": "@user:test",
                 "state_key": "",
                 "room_id": "@room:test",
-                "content": {"body": "foo bar baz"},
+                "content": content,
             },
             RoomVersions.V1,
         )
         room_member_count = 0
         sender_power_level = 0
         power_levels = {}
-        self.evaluator = PushRuleEvaluatorForEvent(
+        return PushRuleEvaluatorForEvent(
             event, room_member_count, sender_power_level, power_levels
         )
 
     def test_display_name(self):
         """Check for a matching display name in the body of the event."""
+        evaluator = self._get_evaluator({"body": "foo bar baz"})
+
         condition = {
             "kind": "contains_display_name",
         }
 
         # Blank names are skipped.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+        self.assertFalse(evaluator.matches(condition, "@user:test", ""))
 
         # Check a display name that doesn't match.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+        self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
 
         # Check a display name which matches.
-        self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
 
         # A display name that matches, but not a full word does not result in a match.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+        self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
 
         # A display name should not be interpreted as a regular expression.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+        self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
 
         # A display name with spaces should work fine.
-        self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+
+    def test_no_body(self):
+        """Not having a body shouldn't break the evaluator."""
+        evaluator = self._get_evaluator({})
+
+        condition = {
+            "kind": "contains_display_name",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+    def test_invalid_body(self):
+        """A non-string body should not break the evaluator."""
+        condition = {
+            "kind": "contains_display_name",
+        }
+
+        for body in (1, True, {"foo": "bar"}):
+            evaluator = self._get_evaluator({"body": body})
+            self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..097e1653b4 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
 from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
 from synapse.replication.tcp.streams.events import (
     EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # also one state event
         state_event = self._inject_state_event()
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # one more bit of state that doesn't get rolled back
         state2 = self._inject_state_event()
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
             prev_events = [e.event_id]
             pl_events.append(e)
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
+    def test_backwards_stream_id(self):
+        """
+        Test that RDATA that comes after the current position should be discarded.
+        """
+        # disconnect, so that we can stack up some changes
+        self.disconnect()
+
+        # Generate an events. We inject them using inject_event so that they are
+        # not send out over replication until we call self.replicate().
+        event = self._inject_test_event()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # We should have received the expected single row (as well as various
+        # cache invalidation updates which we ignore).
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+
+        # There should be a single received row.
+        self.assertEqual(len(received_rows), 1)
+
+        stream_name, token, row = received_rows[0]
+        self.assertEqual("events", stream_name)
+        self.assertIsInstance(row, EventsStreamRow)
+        self.assertEqual(row.type, "ev")
+        self.assertIsInstance(row.data, EventsStreamEventRow)
+        self.assertEqual(row.data.event_id, event.event_id)
+
+        # Reset the data.
+        self.test_handler.received_rdata_rows = []
+
+        # Save the current token for later.
+        worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+        prev_token = worker_events_stream.current_token("master")
+
+        # Manually send an old RDATA command, which should get dropped. This
+        # re-uses the row from above, but with an earlier stream token.
+        self.hs.get_tcp_replication().send_command(
+            RdataCommand("events", "master", 1, row)
+        )
+
+        # No updates have been received (because it was discard as old).
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+        self.assertEqual(len(received_rows), 0)
+
+        # Ensure the stream has not gone backwards.
+        current_token = worker_events_stream.current_token("master")
+        self.assertGreaterEqual(current_token, prev_token)
+
     event_count = 0
 
     def _inject_test_event(
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
 
 from synapse.handlers.typing import RoomMember
 from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from tests.replication._base import BaseStreamTestCase
 
 USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
     def test_typing(self):
         typing = self.hs.get_typing_handler()
 
-        room_id = "!bar:blue"
-
         self.reconnect()
 
-        typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
 
         self.reactor.advance(0)
 
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
-        self.assertEqual(room_id, row.room_id)
+        self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
         # Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         self.test_handler.on_rdata.reset_mock()
 
-        typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
         self.test_handler.on_rdata.assert_not_called()
 
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
-        self.assertEqual(room_id, row.room_id)
+        self.assertEqual(ROOM_ID, row.room_id)
+        self.assertEqual([], row.user_ids)
+
+    def test_reset(self):
+        """
+        Test what happens when a typing stream resets.
+
+        This is emulated by jumping the stream ahead, then reconnecting (which
+        sends the proper position and RDATA).
+        """
+        typing = self.hs.get_typing_handler()
+
+        self.reconnect()
+
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+        self.reactor.advance(0)
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "typing")
+        self.assertEqual(1, len(rdata_rows))
+        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        self.assertEqual(ROOM_ID, row.room_id)
+        self.assertEqual([USER_ID], row.user_ids)
+
+        # Push the stream forward a bunch so it can be reset.
+        for i in range(100):
+            typing._push_update(
+                member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+            )
+        self.reactor.advance(0)
+
+        # Disconnect.
+        self.disconnect()
+
+        # Reset the typing handler
+        self.hs.get_replication_streams()["typing"].last_token = 0
+        self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+        typing._latest_room_serial = 0
+        typing._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", typing._latest_room_serial
+        )
+        typing._reset()
+
+        # Reconnect.
+        self.reconnect()
+        self.pump(0.1)
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+        # Reset the test code.
+        self.test_handler.on_rdata.reset_mock()
+        self.test_handler.on_rdata.assert_not_called()
+
+        # Push additional data.
+        typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+        self.reactor.advance(0)
+
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "typing")
+        self.assertEqual(1, len(rdata_rows))
+        row = rdata_rows[0]
+        self.assertEqual(ROOM_ID_2, row.room_id)
         self.assertEqual([], row.user_ids)
+
+        # The token should have been reset.
+        self.assertEqual(token, 1)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c662195eec..89dcc58b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -30,7 +30,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         room_creator = self.homeserver.get_room_creation_handler()
         room_deferred = ensureDeferred(
             room_creator.create_room(
-                our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+                our_user, room_creator._presets_dict["public_chat"], ratelimit=False
             )
         )
         self.reactor.advance(0.1)