diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index a269c477fb..a82c4eed86 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -35,6 +35,8 @@ def MockEvent(**kwargs):
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
+ if "content" not in kwargs:
+ kwargs["content"] = {}
return make_event_from_dict(kwargs)
@@ -357,6 +359,66 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertTrue(Filter(self.hs, definition)._check(event))
+ @unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
+ def test_filter_rel_type(self):
+ definition = {"org.matrix.msc3874.rel_types": ["m.thread"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={},
+ )
+
+ self.assertFalse(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}},
+ )
+
+ self.assertFalse(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}},
+ )
+
+ self.assertTrue(Filter(self.hs, definition)._check(event))
+
+ @unittest.override_config({"experimental_features": {"msc3874_enabled": True}})
+ def test_filter_not_rel_type(self):
+ definition = {"org.matrix.msc3874.not_rel_types": ["m.thread"]}
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.thread"}},
+ )
+
+ self.assertFalse(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={},
+ )
+
+ self.assertTrue(Filter(self.hs, definition)._check(event))
+
+ event = MockEvent(
+ sender="@foo:bar",
+ type="m.room.message",
+ room_id="!secretbase:unknown",
+ content={"m.relates_to": {"event_id": "$abc", "rel_type": "m.reference"}},
+ )
+
+ self.assertTrue(Filter(self.hs, definition)._check(event))
+
def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}}
filter_id = self.get_success(
@@ -456,7 +518,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_filter_relations(self):
events = [
# An event without a relation.
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c7dae58eb5..8d03da7f96 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
self.assertEqual(channel.code, 401)
-@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock())
+@patch("synapse.app.homeserver.KeyResource", new=Mock())
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 532b676365..11008ac1fb 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -69,10 +69,14 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
self.request_url = None
- async def get_json(url: str, args: Mapping[Any, Any]) -> List[JsonDict]:
- if not args.get(b"access_token"):
+ async def get_json(
+ url: str, args: Mapping[Any, Any], headers: Mapping[Any, Any]
+ ) -> List[JsonDict]:
+ # Ensure the access token is passed as both a header and query arg.
+ if not headers.get("Authorization") or not args.get(b"access_token"):
raise RuntimeError("Access token not provided")
+ self.assertEqual(headers.get("Authorization"), f"Bearer {TOKEN}")
self.assertEqual(args.get(b"access_token"), TOKEN)
self.request_url = url
if url == URL_USER:
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 50e376f695..a538215931 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -23,14 +23,23 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
+from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase
class FederationClientTest(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
super().prepare(reactor, clock, homeserver)
@@ -231,6 +240,72 @@ class FederationClientTest(FederatingHomeserverTestCase):
return remote_pdu
+ def test_backfill_invalid_signature_records_failed_pull_attempts(
+ self,
+ ) -> None:
+ """
+ Test to make sure that events from /backfill with invalid signatures get
+ recorded as failed pull attempts.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # We purposely don't run `add_hashes_and_signatures_from_other_server`
+ # over this because we want the signature check to fail.
+ pulled_event, _ = self.get_success(
+ event_injection.create_event(
+ self.hs,
+ room_id=room_id,
+ sender=OTHER_USER,
+ type="test_event_type",
+ content={"body": "garply"},
+ )
+ )
+
+ # We expect an outbound request to /backfill, so stub that out
+ self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
+ _mock_response(
+ {
+ "origin": "yet.another.server",
+ "origin_server_ts": 900,
+ # Mimic the other server returning our new `pulled_event`
+ "pdus": [pulled_event.get_pdu_json()],
+ }
+ )
+ )
+
+ self.get_success(
+ self.hs.get_federation_client().backfill(
+ # We use "yet.another.server" instead of
+ # `self.OTHER_SERVER_NAME` because we want to see the behavior
+ # from `_check_sigs_and_hash_and_fetch_one` where it tries to
+ # fetch the PDU again from the origin server if the signature
+ # fails. Just want to make sure that the failure is counted from
+ # both code paths.
+ dest="yet.another.server",
+ room_id=room_id,
+ limit=1,
+ extremities=[pulled_event.event_id],
+ ),
+ )
+
+ # Make sure our failed pull attempt was recorded
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
+ # other from "yet.another.server"
+ self.assertEqual(backfill_num_attempts, 2)
+
def _mock_response(resp: JsonDict):
body = json.dumps(resp).encode("utf-8")
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index c2320ce133..dd4d1b56de 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -13,9 +13,11 @@
# limitations under the License.
import json
+from unittest.mock import Mock
from synapse.api.room_versions import RoomVersions
from synapse.federation.transport.client import SendJoinParser
+from synapse.util import ExceptionBundle
from tests.unittest import TestCase
@@ -94,3 +96,37 @@ class SendJoinParserTestCase(TestCase):
# Retrieve and check the parsed SendJoinResponse
parsed_response = parser.finish()
self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"])
+
+ def test_errors_closing_coroutines(self) -> None:
+ """Check we close all coroutines, even if closing the first raises an Exception.
+
+ We also check that an Exception of some kind is raised, but we don't make any
+ assertions about its attributes or type.
+ """
+ parser = SendJoinParser(RoomVersions.V1, False)
+ response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]}
+ serialisation = json.dumps(response).encode()
+
+ # Mock the coroutines managed by this parser.
+ # The first one will error when we try to close it.
+ coro_1 = Mock()
+ coro_1.close = Mock(side_effect=RuntimeError("Couldn't close coro 1"))
+
+ coro_2 = Mock()
+
+ coro_3 = Mock()
+ coro_3.close = Mock(side_effect=RuntimeError("Couldn't close coro 3"))
+
+ parser._coros = [coro_1, coro_2, coro_3]
+
+ # Send half of the data to the parser
+ parser.write(serialisation[: len(serialisation) // 2])
+
+ # Close the parser. There should be _some_ kind of exception.
+ with self.assertRaises(ExceptionBundle):
+ parser.finish()
+
+ # In any case, we should have tried to close both coros.
+ coro_1.close.assert_called()
+ coro_2.close.assert_called()
+ coro_3.close.assert_called()
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index af24c4984d..7e4570f990 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [], {})),
- make_awaitable((1, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {})),
+ make_awaitable((1, {event.event_id: 0})),
+ ]
+ self.mock_store.get_events_as_list.side_effect = [
+ make_awaitable([]),
+ make_awaitable([event]),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {event.event_id: 0})),
]
-
+ self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
]
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 745750b1d7..d00c69c229 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -19,7 +19,13 @@ from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ LimitExceededError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
@@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
self.get_success(d)
+ def test_backfill_ignores_known_events(self) -> None:
+ """
+ Tests that events that we already know about are ignored when backfilling.
+ """
+ # Set up users
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # Create a room to backfill events into
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # Build an event to backfill
+ event = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"body": "hello world", "msgtype": "m.text"},
+ "room_id": room_id,
+ "sender": other_user,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ # Ensure the event is not already in the DB
+ self.get_failure(
+ self.store.get_event(event.event_id),
+ NotFoundError,
+ )
+
+ # Backfill the event and check that it has entered the DB.
+
+ # We mock out the FederationClient.backfill method, to pretend that a remote
+ # server has returned our fake event.
+ federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock
+
+ # We also mock the persist method with a side effect of itself. This allows us
+ # to track when it has been called while preserving its function.
+ persist_events_and_notify_mock = Mock(
+ side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
+ )
+ self.hs.get_federation_event_handler().persist_events_and_notify = (
+ persist_events_and_notify_mock
+ )
+
+ # Small side-tangent. We populate the event cache with the event, even though
+ # it is not yet in the DB. This is an invalid scenario that can currently occur
+ # due to not properly invalidating the event cache.
+ # See https://github.com/matrix-org/synapse/issues/13476.
+ #
+ # As a result, backfill should not rely on the event cache to check whether
+ # we already have an event in the DB.
+ # TODO: Remove this bit when the event cache is properly invalidated.
+ cache_entry = EventCacheEntry(
+ event=event,
+ redacted_event=None,
+ )
+ self.store._get_event_cache.set_local((event.event_id,), cache_entry)
+
+ # We now call FederationEventHandler.backfill (a separate method) to trigger
+ # a backfill request. It should receive the fake event.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ )
+ )
+
+ # Check that our fake event was persisted.
+ persist_events_and_notify_mock.assert_called_once()
+ persist_events_and_notify_mock.reset_mock()
+
+ # Now we repeat the backfill, having the homeserver receive the fake event
+ # again.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ ),
+ )
+
+ # This time, we expect no event persistence to have occurred, as we already
+ # have this event.
+ persist_events_and_notify_mock.assert_not_called()
+
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 918010cddb..e448cb1901 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -14,7 +14,7 @@
from typing import Optional
from unittest import mock
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
check_state_dependent_auth_rules,
@@ -43,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
def make_homeserver(self, reactor, clock):
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
- spec=["get_room_state_ids", "get_room_state", "get_event"]
+ spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
@@ -459,6 +459,203 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
self.assertIsNotNone(persisted, "pulled event was not persisted at all")
+ def test_backfill_signature_failure_does_not_fetch_same_prev_event_later(
+ self,
+ ) -> None:
+ """
+ Test to make sure we backoff and don't try to fetch a missing prev_event when we
+ already know it has a invalid signature from checking the signatures of all of
+ the events in the backfill response.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # Add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ # We purposely don't run `add_hashes_and_signatures_from_other_server`
+ # over this because we want the signature check to fail.
+ pulled_event_without_signatures = make_event_from_dict(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event_without_signatures"},
+ },
+ room_version,
+ )
+
+ # Create a regular event that should pass except for the
+ # `pulled_event_without_signatures` in the `prev_event`.
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ member_event.event_id,
+ pulled_event_without_signatures.event_id,
+ ],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event"},
+ }
+ ),
+ room_version,
+ )
+
+ # We expect an outbound request to /backfill, so stub that out
+ self.mock_federation_transport_client.backfill.return_value = make_awaitable(
+ {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ # This is one of the important aspects of this test: we include
+ # `pulled_event_without_signatures` so it fails the signature check
+ # when we filter down the backfill response down to events which
+ # have valid signatures in
+ # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+ pulled_event_without_signatures.get_pdu_json(),
+ # Then later when we process this valid signature event, when we
+ # fetch the missing `prev_event`s, we want to make sure that we
+ # backoff and don't try and fetch `pulled_event_without_signatures`
+ # again since we know it just had an invalid signature.
+ pulled_event.get_pdu_json(),
+ ],
+ }
+ )
+
+ # Keep track of the count and make sure we don't make any of these requests
+ event_endpoint_requested_count = 0
+ room_state_ids_endpoint_requested_count = 0
+ room_state_endpoint_requested_count = 0
+
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> None:
+ nonlocal event_endpoint_requested_count
+ event_endpoint_requested_count += 1
+
+ async def get_room_state_ids(
+ destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_ids_endpoint_requested_count
+ room_state_ids_endpoint_requested_count += 1
+
+ async def get_room_state(
+ room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_endpoint_requested_count
+ room_state_endpoint_requested_count += 1
+
+ # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in
+ # the happy path but if the logic is sneaking around what we expect, stub that
+ # out so we can detect that failure
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+ self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+ get_room_state_ids
+ )
+ self.mock_federation_transport_client.get_room_state.side_effect = (
+ get_room_state
+ )
+
+ # The function under test: try to backfill and process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ self.OTHER_SERVER_NAME,
+ room_id,
+ limit=1,
+ extremities=["$some_extremity"],
+ )
+ )
+
+ if event_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /event in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_ids_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state_ids in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ # Make sure we only recorded a single failure which corresponds to the signature
+ # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before
+ # we process all of the pulled events.
+ backfill_num_attempts_for_event_without_signatures = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event_without_signatures.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1)
+
+ # And make sure we didn't record a failure for the event that has the missing
+ # prev_event because we don't want to cause a cascade of failures. Not being
+ # able to fetch the `prev_events` just means we won't be able to de-outlier the
+ # pulled event. But we can still use an `outlier` in the state/auth chain for
+ # another event. So we shouldn't stop a downstream event from trying to pull it.
+ self.get_failure(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ ),
+ # StoreError: 404: No row found
+ StoreError,
+ )
+
def test_process_pulled_event_with_rejected_missing_state(self) -> None:
"""Ensure that we correctly handle pulled events with missing state containing a
rejected state event
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index c8cc21cadd..a801f002a0 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -25,6 +25,8 @@ class ServerNameTestCase(unittest.TestCase):
"[0abc:1def::1234]": ("[0abc:1def::1234]", None),
"1.2.3.4:1": ("1.2.3.4", 1),
"[0abc:1def::1234]:8080": ("[0abc:1def::1234]", 8080),
+ ":80": ("", 80),
+ "": ("", None),
}
for i, o in test_data.items():
@@ -42,6 +44,7 @@ class ServerNameTestCase(unittest.TestCase):
"newline.com\n",
".empty-label.com",
"1234:5678:80", # too many colons
+ ":80",
]
for i in test_data:
try:
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 3cbca0f5a3..46166292fe 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -35,11 +35,13 @@ from tests.http.server._base import test_disconnect
def make_request(content):
"""Make an object that acts enough like a request."""
- request = Mock(spec=["content"])
+ request = Mock(spec=["method", "uri", "content"])
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
+ request.method = bytes("STUB_METHOD", "ascii")
+ request.uri = bytes("/test_stub_uri", "ascii")
request.content = BytesIO(content)
return request
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 96f399b7ab..0b0d8737c1 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -153,6 +153,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site.site_tag = "test-site"
site.server_version_string = "Server v1"
site.reactor = Mock()
+ site.experimental_cors_msc3886 = False
request = SynapseRequest(FakeChannel(site, None), site)
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
new file mode 100644
index 0000000000..675d7df2ac
--- /dev/null
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -0,0 +1,74 @@
+from unittest.mock import patch
+
+from synapse.api.room_versions import RoomVersions
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
+from synapse.rest import admin
+from synapse.rest.client import login, register, room
+from synapse.types import create_requester
+
+from tests import unittest
+
+
+class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ ]
+
+ def test_action_for_event_by_user_handles_noninteger_power_levels(self) -> None:
+ """We should convert floats and strings to integers before passing to Rust.
+
+ Reproduces #14060.
+
+ A lack of validation: the gift that keeps on giving.
+ """
+ # Create a new user and room.
+ alice = self.register_user("alice", "pass")
+ token = self.login(alice, "pass")
+
+ room_id = self.helper.create_room_as(
+ alice, room_version=RoomVersions.V9.identifier, tok=token
+ )
+
+ # Alter the power levels in that room to include stringy and floaty levels.
+ # We need to suppress the validation logic or else it will reject these dodgy
+ # values. (Presumably this validation was not always present.)
+ event_creation_handler = self.hs.get_event_creation_handler()
+ requester = create_requester(alice)
+ with patch("synapse.events.validator.validate_canonicaljson"), patch(
+ "synapse.events.validator.jsonschema.validate"
+ ):
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {
+ "users": {alice: "100"}, # stringy
+ "notifications": {"room": 100.0}, # float
+ },
+ token,
+ state_key="",
+ )
+
+ # Create a new message event, and try to evaluate it under the dodgy
+ # power level event.
+ event, context = self.get_success(
+ event_creation_handler.create_event(
+ requester,
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "content": {
+ "msgtype": "m.text",
+ "body": "helo",
+ },
+ "sender": alice,
+ },
+ )
+ )
+
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+ # should not raise
+ self.get_success(bulk_evaluator.action_for_event_by_user(event, context))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 8804f0e0d3..decf619466 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Set, Tuple, Union
+from typing import Dict, Optional, Union
import frozendict
@@ -38,12 +38,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event
class PushRuleEvaluatorTestCase(unittest.TestCase):
- def _get_evaluator(
- self,
- content: JsonDict,
- relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None,
- relations_match_enabled: bool = False,
- ) -> PushRuleEvaluator:
+ def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator:
event = FrozenEvent(
{
"event_id": "$event_id",
@@ -63,8 +58,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_member_count,
sender_power_level,
power_levels.get("notifications", {}),
- relations or {},
- relations_match_enabled,
)
def test_display_name(self) -> None:
@@ -299,71 +292,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
{"sound": "default", "highlight": True},
)
- def test_relation_match(self) -> None:
- """Test the relation_match push rule kind."""
-
- # Check if the experimental feature is disabled.
- evaluator = self._get_evaluator(
- {}, {"m.annotation": {("@user:test", "m.reaction")}}
- )
-
- # A push rule evaluator with the experimental rule enabled.
- evaluator = self._get_evaluator(
- {}, {"m.annotation": {("@user:test", "m.reaction")}}, True
- )
-
- # Check just relation type.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check relation type and sender.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "sender": "@user:test",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "sender": "@other:test",
- }
- self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check relation type and event type.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "type": "m.reaction",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check just sender, this fails since rel_type is required.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "sender": "@user:test",
- }
- self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check sender glob.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "sender": "@*:test",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
- # Check event type glob.
- condition = {
- "kind": "org.matrix.msc3772.relation_match",
- "rel_type": "m.annotation",
- "event_type": "*.reaction",
- }
- self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
-
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator"""
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index efd92793c0..d42e36cdf1 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -22,7 +22,10 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.event_push_actions import (
+ NotifCounts,
+ RoomNotifCounts,
+)
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
@@ -178,7 +181,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
- NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
+ RoomNotifCounts(
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}
+ ),
)
self.persist(
@@ -191,7 +196,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
- NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
+ RoomNotifCounts(
+ NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}
+ ),
)
self.persist(
@@ -206,7 +213,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
- NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
+ RoomNotifCounts(
+ NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}
+ ),
)
def test_get_rooms_for_user_with_stream_ordering(self):
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index fef3b72d76..ddf315b894 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -654,6 +654,14 @@ class RelationsTestCase(BaseRelationsTestCase):
)
# We also expect to get the original event (the id of which is self.parent_id)
+ # when requesting the unstable endpoint.
+ self.assertNotIn("original_event", channel.json_body)
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(
channel.json_body["original_event"]["event_id"], self.parent_id
)
@@ -755,11 +763,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel.json_body["chunk"][0],
)
- # We also expect to get the original event (the id of which is self.parent_id)
- self.assertEqual(
- channel.json_body["original_event"]["event_id"], self.parent_id
- )
-
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
@@ -1674,7 +1677,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
{"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
)
- @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_parent_thread(self) -> None:
"""
Test that thread replies are still available when the root event is redacted.
@@ -1704,3 +1706,165 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
related_event_id,
)
+
+
+class ThreadsTestCase(BaseRelationsTestCase):
+ def _get_threads(self, body: JsonDict) -> List[Tuple[str, str]]:
+ return [
+ (
+ ev["event_id"],
+ ev["unsigned"]["m.relations"]["m.thread"]["latest_event"]["event_id"],
+ )
+ for ev in body["chunk"]
+ ]
+
+ def test_threads(self) -> None:
+ """Create threads and ensure the ordering is due to their latest event."""
+ # Create 2 threads.
+ thread_1 = self.parent_id
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+ thread_2 = res["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", parent_id=thread_2
+ )
+ reply_2 = channel.json_body["event_id"]
+
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ threads = self._get_threads(channel.json_body)
+ self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
+
+ # Update the first thread, the ordering should swap.
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ reply_3 = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ # Tuple of (thread ID, latest event ID) for each thread.
+ threads = self._get_threads(channel.json_body)
+ self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
+
+ def test_pagination(self) -> None:
+ """Create threads and paginate through them."""
+ # Create 2 threads.
+ thread_1 = self.parent_id
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
+ thread_2 = res["event_id"]
+
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Request the threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_2])
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ next_batch = channel.json_body.get("next_batch")
+ self.assertIsInstance(next_batch, str, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_1], channel.json_body)
+
+ self.assertNotIn("next_batch", channel.json_body, channel.json_body)
+
+ def test_include(self) -> None:
+ """Filtering threads to all or participated in should work."""
+ # Thread 1 has the user as the root event.
+ thread_1 = self.parent_id
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+
+ # Thread 2 has the user replying.
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+ thread_2 = res["event_id"]
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Thread 3 has the user not participating in.
+ res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
+ thread_3 = res["event_id"]
+ self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.test",
+ access_token=self.user2_token,
+ parent_id=thread_3,
+ )
+
+ # All threads in the room.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(
+ thread_roots, [thread_3, thread_2, thread_1], channel.json_body
+ )
+
+ # Only participated threads.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
+
+ def test_ignored_user(self) -> None:
+ """Events from ignored users should be ignored."""
+ # Thread 1 has a reply from an ignored user.
+ thread_1 = self.parent_id
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+
+ # Thread 2 is created by an ignored user.
+ res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
+ thread_2 = res["event_id"]
+ self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
+
+ # Ignore user2.
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.user_id,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user2_id: {}}},
+ )
+ )
+
+ # Only thread 1 is returned.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/threads",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
+ self.assertEqual(thread_roots, [thread_1], channel.json_body)
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
new file mode 100644
index 0000000000..ad00a476e1
--- /dev/null
+++ b/tests/rest/client/test_rendezvous.py
@@ -0,0 +1,45 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest.client import rendezvous
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
+
+
+class RendezvousServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ rendezvous.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+ return self.hs
+
+ def test_disabled(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 400)
+
+ @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
+ def test_redirect(self) -> None:
+ channel = self.make_request("POST", endpoint, {}, access_token=None)
+ self.assertEqual(channel.code, 307)
+ self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 7f8cf4fab0..716366eb90 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
import json
from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call
+from unittest.mock import Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
@@ -35,13 +35,15 @@ from synapse.api.constants import (
EventTypes,
Membership,
PublicRoomsFilterFields,
- RelationTypes,
RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
+from synapse.appservice import ApplicationService
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
-from synapse.rest.client import account, directory, login, profile, room, sync
+from synapse.rest.client import account, directory, login, profile, register, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
@@ -49,7 +51,9 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
+from tests.storage.test_stream import PaginationTestCase
from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import create_event
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -1252,6 +1256,120 @@ class RoomJoinTestCase(RoomBase):
)
+class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ room.register_servlets,
+ synapse.rest.admin.register_servlets,
+ register.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.appservice_user, _ = self.register_appservice_user(
+ "as_user_potato", self.appservice.token
+ )
+
+ # Create a room as the appservice user.
+ args = {
+ "access_token": self.appservice.token,
+ "user_id": self.appservice_user,
+ }
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/createRoom?{urlparse.urlencode(args)}",
+ content={"visibility": "public"},
+ )
+
+ assert channel.code == 200
+ self.room = channel.json_body["room_id"]
+
+ self.main_store = self.hs.get_datastores().main
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ self.appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+
+ mock_load_appservices = Mock(return_value=[self.appservice])
+ with patch(
+ "synapse.storage.databases.main.appservice.load_appservices",
+ mock_load_appservices,
+ ):
+ hs = self.setup_test_homeserver(config=config)
+ return hs
+
+ def test_send_event_ts(self) -> None:
+ """Test sending a non-state event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/send/m.room.message/1234?"
+ + urlparse.urlencode(url_params),
+ content={"body": "test", "msgtype": "m.text"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+ def test_send_state_event_ts(self) -> None:
+ """Test sending a state event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.name?"
+ + urlparse.urlencode(url_params),
+ content={"name": "test"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+ def test_send_membership_event_ts(self) -> None:
+ """Test sending a membership event with a custom timestamp."""
+ ts = 1
+
+ url_params = {
+ "user_id": self.appservice_user,
+ "ts": ts,
+ }
+ channel = self.make_request(
+ "PUT",
+ path=f"/_matrix/client/r0/rooms/{self.room}/state/m.room.member/{self.appservice_user}?"
+ + urlparse.urlencode(url_params),
+ content={"membership": "join", "display_name": "test"},
+ access_token=self.appservice.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ # Ensure the event was persisted with the correct timestamp.
+ res = self.get_success(self.main_store.get_event(event_id))
+ self.assertEquals(ts, res.origin_server_ts)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@@ -2098,14 +2216,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
)
def make_public_rooms_request(
- self, room_types: Union[List[Union[str, None]], None]
+ self,
+ room_types: Optional[List[Union[str, None]]],
+ instance_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
- channel = self.make_request(
- "POST",
- self.url,
- {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
- self.token,
- )
+ body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
+ if instance_id:
+ body["third_party_instance_id"] = "test|test"
+
+ channel = self.make_request("POST", self.url, body, self.token)
+ self.assertEqual(channel.code, 200)
+
chunk = channel.json_body["chunk"]
count = channel.json_body["total_room_count_estimate"]
@@ -2115,31 +2236,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
chunk, count = self.make_public_rooms_request(None)
-
self.assertEqual(count, 2)
+ # Also check if there's no filter property at all in the body.
+ channel = self.make_request("POST", self.url, {}, self.token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["chunk"]), 2)
+ self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
+
+ chunk, count = self.make_public_rooms_request(None, "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_rooms_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), None)
+ chunk, count = self.make_public_rooms_request([None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), "m.space")
+ chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
+ self.assertEqual(count, 0)
+
def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
chunk, count = self.make_public_rooms_request([])
-
self.assertEqual(count, 2)
+ chunk, count = self.make_public_rooms_request([], "test|test")
+ self.assertEqual(count, 0)
+
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"""Test that we correctly fallback to local filtering if a remote server
@@ -2779,149 +2918,20 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id
-class RelationsTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["experimental_features"] = {"msc3440_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user_id = self.register_user("test", "test")
- self.tok = self.login("test", "test")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- self.second_user_id = self.register_user("second", "test")
- self.second_tok = self.login("second", "test")
- self.helper.join(
- room=self.room_id, user=self.second_user_id, tok=self.second_tok
- )
-
- self.third_user_id = self.register_user("third", "test")
- self.third_tok = self.login("third", "test")
- self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
-
- # An initial event with a relation from second user.
- res = self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "Message 1"},
- tok=self.tok,
- )
- self.event_id_1 = res["event_id"]
- self.helper.send_event(
- room_id=self.room_id,
- type="m.reaction",
- content={
- "m.relates_to": {
- "rel_type": RelationTypes.ANNOTATION,
- "event_id": self.event_id_1,
- "key": "👍",
- }
- },
- tok=self.second_tok,
- )
-
- # Another event with a relation from third user.
- res = self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "Message 2"},
- tok=self.tok,
- )
- self.event_id_2 = res["event_id"]
- self.helper.send_event(
- room_id=self.room_id,
- type="m.reaction",
- content={
- "m.relates_to": {
- "rel_type": RelationTypes.REFERENCE,
- "event_id": self.event_id_2,
- }
- },
- tok=self.third_tok,
- )
-
- # An event with no relations.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"msgtype": "m.text", "body": "No relations"},
- tok=self.tok,
- )
-
- def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
+class RelationsTestCase(PaginationTestCase):
+ def _filter_messages(self, filter: JsonDict) -> List[str]:
"""Make a request to /messages with a filter, returns the chunk of events."""
+ from_token = self.get_success(
+ self.from_token.to_string(self.hs.get_datastores().main)
+ )
channel = self.make_request(
"GET",
- "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
+ f"/rooms/{self.room_id}/messages?filter={json.dumps(filter)}&dir=f&from={from_token}",
access_token=self.tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- return channel.json_body["chunk"]
-
- def test_filter_relation_senders(self) -> None:
- # Messages which second user reacted to.
- filter = {"related_by_senders": [self.second_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
-
- # Messages which third user reacted to.
- filter = {"related_by_senders": [self.third_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_2)
-
- # Messages which either user reacted to.
- filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
- )
-
- def test_filter_relation_type(self) -> None:
- # Messages which have annotations.
- filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
-
- # Messages which have references.
- filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_2)
-
- # Messages which have either annotations or references.
- filter = {
- "related_by_rel_types": [
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- ]
- }
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
- )
-
- def test_filter_relation_senders_and_type(self) -> None:
- # Messages which second user reacted to.
- filter = {
- "related_by_senders": [self.second_user_id],
- "related_by_rel_types": [RelationTypes.ANNOTATION],
- }
- chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0]["event_id"], self.event_id_1)
+ return [ev["event_id"] for ev in channel.json_body["chunk"]]
class ContextTestCase(unittest.HomeserverTestCase):
@@ -3479,3 +3489,65 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
+
+
+class TimestampLookupTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = {"msc3030_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self._storage_controllers = self.hs.get_storage_controllers()
+
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ def _inject_outlier(self, room_id: str) -> EventBase:
+ event, _context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=room_id,
+ type="m.test",
+ sender="@test_remote_user:remote",
+ )
+ )
+
+ event.internal_metadata.outlier = True
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
+ )
+ )
+ return event
+
+ def test_no_outliers(self) -> None:
+ """
+ Test to make sure `/timestamp_to_event` does not return `outlier` events.
+ We're unable to determine whether an `outlier` is next to a gap so we
+ don't know whether it's actually the closest event. Instead, let's just
+ ignore `outliers` with this endpoint.
+
+ This test is really seeing that we choose the non-`outlier` event behind the
+ `outlier`. Since the gap checking logic considers the latest message in the room
+ as *not* next to a gap, asking over federation does not come into play here.
+ """
+ room_id = self.helper.create_room_as(self.room_owner, tok=self.room_owner_tok)
+
+ outlier_event = self._inject_outlier(room_id)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+ access_token=self.room_owner_tok,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+
+ # Make sure the outlier event is not returned
+ self.assertNotEqual(channel.json_body["event_id"], outlier_event.event_id)
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index 61b66d7685..fdc433a8b5 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -59,7 +59,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.event_source.get_new_events(
user=UserID.from_string(self.user_id),
from_key=0,
- limit=None,
+ # Limit is unused.
+ limit=0,
room_ids=[self.room_id],
is_guest=False,
)
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index ac0ac06b7e..7f1fba1086 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
@@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def create_test_resource(self) -> Resource:
return create_resource_tree(
- {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+ {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource()
)
def expect_outgoing_key_request(
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index f38d7225f8..319ae8b1cc 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -14,6 +14,8 @@
import json
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
@@ -23,8 +25,16 @@ from synapse.util import Clock
from tests.unittest import HomeserverTestCase
+try:
+ import lxml
+except ImportError:
+ lxml = None
+
class OEmbedTests(HomeserverTestCase):
+ if not lxml:
+ skip = "url preview feature requires lxml"
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.oembed = OEmbedProvider(hs)
@@ -36,7 +46,7 @@ class OEmbedTests(HomeserverTestCase):
def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
- result = self.parse_response({"version": version, "type": "link"})
+ result = self.parse_response({"version": version})
# An empty Open Graph response is an error, ensure the URL is included.
self.assertIn("og:url", result.open_graph_result)
@@ -49,3 +59,94 @@ class OEmbedTests(HomeserverTestCase):
result = self.parse_response({"version": version, "type": "link"})
# An empty Open Graph response is an error, ensure the URL is included.
self.assertEqual({}, result.open_graph_result)
+
+ def test_cache_age(self) -> None:
+ """Ensure a cache-age is parsed properly."""
+ # Correct-ish cache ages are allowed.
+ for cache_age in ("1", 1.0, 1):
+ result = self.parse_response({"cache_age": cache_age})
+ self.assertEqual(result.cache_age, 1000)
+
+ # Invalid cache ages are ignored.
+ for cache_age in ("invalid", {}):
+ result = self.parse_response({"cache_age": cache_age})
+ self.assertIsNone(result.cache_age)
+
+ # Cache age is optional.
+ result = self.parse_response({})
+ self.assertIsNone(result.cache_age)
+
+ @parameterized.expand(
+ [
+ ("title", "title"),
+ ("provider_name", "site_name"),
+ ("thumbnail_url", "image"),
+ ],
+ name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}",
+ )
+ def test_property(self, oembed_property: str, open_graph_property: str) -> None:
+ """Test properties which must be strings."""
+ result = self.parse_response({oembed_property: "test"})
+ self.assertIn(f"og:{open_graph_property}", result.open_graph_result)
+ self.assertEqual(result.open_graph_result[f"og:{open_graph_property}"], "test")
+
+ result = self.parse_response({oembed_property: 1})
+ self.assertNotIn(f"og:{open_graph_property}", result.open_graph_result)
+
+ def test_author_name(self) -> None:
+ """Test the author_name property."""
+ result = self.parse_response({"author_name": "test"})
+ self.assertEqual(result.author_name, "test")
+
+ result = self.parse_response({"author_name": 1})
+ self.assertIsNone(result.author_name)
+
+ def test_rich(self) -> None:
+ """Test a type of rich."""
+ result = self.parse_response({"html": "test<img src='foo'>", "type": "rich"})
+ self.assertIn("og:description", result.open_graph_result)
+ self.assertIn("og:image", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:description"], "test")
+ self.assertEqual(result.open_graph_result["og:image"], "foo")
+
+ result = self.parse_response({"type": "rich"})
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ result = self.parse_response({"html": 1, "type": "rich"})
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ def test_photo(self) -> None:
+ """Test a type of photo."""
+ result = self.parse_response({"url": "test", "type": "photo"})
+ self.assertIn("og:image", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:image"], "test")
+
+ result = self.parse_response({"type": "photo"})
+ self.assertNotIn("og:image", result.open_graph_result)
+
+ result = self.parse_response({"url": 1, "type": "photo"})
+ self.assertNotIn("og:image", result.open_graph_result)
+
+ def test_video(self) -> None:
+ """Test a type of video."""
+ result = self.parse_response({"html": "test", "type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertIn("og:description", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:description"], "test")
+
+ result = self.parse_response({"type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ result = self.parse_response({"url": 1, "type": "video"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "video.other")
+ self.assertNotIn("og:description", result.open_graph_result)
+
+ def test_link(self) -> None:
+ """Test type of link."""
+ result = self.parse_response({"type": "link"})
+ self.assertIn("og:type", result.open_graph_result)
+ self.assertEqual(result.open_graph_result["og:type"], "website")
diff --git a/tests/server.py b/tests/server.py
index c447d5e4c4..8b1d186219 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -266,7 +266,12 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
- def __init__(self, resource: IResource, reactor: IReactorTime):
+ def __init__(
+ self,
+ resource: IResource,
+ reactor: IReactorTime,
+ experimental_cors_msc3886: bool = False,
+ ):
"""
Args:
@@ -274,6 +279,7 @@ class FakeSite:
"""
self._resource = resource
self.reactor = reactor
+ self.experimental_cors_msc3886 = experimental_cors_msc3886
def getResourceFor(self, request):
return self._resource
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 32a798d74b..5773172ab8 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -90,18 +90,6 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
- def test_query_via_event_cache(self):
- # fetch an event into the event cache
- self.get_success(self.store.get_event(self.event_ids[0]))
-
- # looking it up should now cause no db hits
- with LoggingContext(name="test") as ctx:
- res = self.get_success(
- self.store.have_seen_events(self.room_id, [self.event_ids[0]])
- )
- self.assertEqual(res, {self.event_ids[0]})
- self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
-
def test_persisting_event_invalidates_cache(self):
"""
Test to make sure that the `have_seen_event` cache
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 59b8910907..853db930d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -27,6 +27,8 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.events import _EventInternalMetadata
+from synapse.rest import admin
+from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
@@ -43,6 +45,12 @@ class _BackfillSetupInfo:
class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
@@ -1122,6 +1130,62 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
+ def test_get_event_ids_to_not_pull_from_backoff(
+ self,
+ ):
+ """
+ Test to make sure only event IDs we should backoff from are returned.
+ """
+ # Create the room
+ user_id = self.register_user("alice", "test")
+ tok = self.login("alice", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "$failed_event_id", "fake cause"
+ )
+ )
+
+ event_ids_to_backoff = self.get_success(
+ self.store.get_event_ids_to_not_pull_from_backoff(
+ room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
+ )
+ )
+
+ self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
+
+ def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
+ self,
+ ):
+ """
+ Test to make sure no event IDs are returned after the backoff duration has
+ elapsed.
+ """
+ # Create the room
+ user_id = self.register_user("alice", "test")
+ tok = self.login("alice", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "$failed_event_id", "fake cause"
+ )
+ )
+
+ # Now advance time by 2 hours so we wait long enough for the single failed
+ # attempt (2^1 hours).
+ self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+ event_ids_to_backoff = self.get_success(
+ self.store.get_event_ids_to_not_pull_from_backoff(
+ room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
+ )
+ )
+ # Since this function only returns events we should backoff from, time has
+ # elapsed past the backoff range so there is no events to backoff from.
+ self.assertEqual(event_ids_to_backoff, [])
+
@attr.s
class FakeEvent:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 473c965e19..ee48920f84 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
+from typing import Optional, Tuple
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -64,16 +66,23 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
user_id, token, _, other_token, room_id = self._create_users_and_room()
# Create two events, one of which is a highlight.
- self.helper.send_event(
+ first_event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": "msg"},
tok=other_token,
- )
- event_id = self.helper.send_event(
+ )["event_id"]
+ second_event_id = self.helper.send_event(
room_id,
type="m.room.message",
- content={"msgtype": "m.text", "body": user_id},
+ content={
+ "msgtype": "m.text",
+ "body": user_id,
+ "m.relates_to": {
+ "rel_type": RelationTypes.THREAD,
+ "event_id": first_event_id,
+ },
+ },
tok=other_token,
)["event_id"]
@@ -93,13 +102,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
self.assertEqual(2, len(email_actions))
- # Send a receipt, which should clear any actions.
+ # Send a receipt, which should clear the first action.
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
- event_ids=[event_id],
+ event_ids=[first_event_id],
thread_id=None,
data={},
)
@@ -109,6 +118,30 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
user_id, 0, 1000, 20
)
)
+ self.assertEqual(1, len(http_actions))
+ email_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(1, len(email_actions))
+
+ # Send a thread receipt to clear the thread action.
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[second_event_id],
+ thread_id=first_event_id,
+ data={},
+ )
+ )
+ http_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ user_id, 0, 1000, 20
+ )
+ )
self.assertEqual([], http_actions)
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
@@ -133,13 +166,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(
- counts,
+ counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
unread_count=0,
highlight_count=highlight_count,
),
)
+ self.assertEqual(counts.threads, {})
def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
@@ -186,6 +220,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_assert_counts(0, 0)
_create_event()
+ _assert_counts(1, 0)
_rotate()
_assert_counts(1, 0)
@@ -236,6 +271,444 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_rotate()
_assert_counts(0, 0)
+ def test_count_aggregation_threads(self) -> None:
+ """
+ This is essentially the same test as test_count_aggregation, but adds
+ events to the main timeline and to a thread.
+ """
+
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+ thread_id: str
+
+ last_event_id: str
+
+ def _assert_counts(
+ noitf_count: int,
+ highlight_count: int,
+ thread_notif_count: int,
+ thread_highlight_count: int,
+ ) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count,
+ unread_count=0,
+ highlight_count=highlight_count,
+ ),
+ )
+ if thread_notif_count or thread_highlight_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=thread_highlight_count,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ def _create_event(
+ highlight: bool = False, thread_id: Optional[str] = None
+ ) -> str:
+ content: JsonDict = {
+ "msgtype": "m.text",
+ "body": user_id if highlight else "msg",
+ }
+ if thread_id:
+ content["m.relates_to"] = {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ }
+
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content=content,
+ tok=other_token,
+ )
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
+
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
+
+ def _mark_read(event_id: str, thread_id: str = MAIN_TIMELINE) -> None:
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ thread_id=thread_id,
+ data={},
+ )
+ )
+
+ _assert_counts(0, 0, 0, 0)
+ thread_id = _create_event()
+ _assert_counts(1, 0, 0, 0)
+ _rotate()
+ _assert_counts(1, 0, 0, 0)
+
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ _create_event()
+ _assert_counts(2, 0, 1, 0)
+ _rotate()
+ _assert_counts(2, 0, 1, 0)
+
+ event_id = _create_event(thread_id=thread_id)
+ _assert_counts(2, 0, 2, 0)
+ _rotate()
+ _assert_counts(2, 0, 2, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _mark_read(event_id)
+ _assert_counts(1, 0, 3, 0)
+ _mark_read(event_id, thread_id)
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ # Delete old event push actions, this should not affect the (summarised) count.
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _assert_counts(1, 1, 0, 0)
+ _rotate()
+ _assert_counts(1, 1, 0, 0)
+
+ event_id = _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _rotate()
+ _assert_counts(1, 1, 1, 1)
+
+ # Check that adding another notification and rotating after highlight
+ # works.
+ _create_event()
+ _rotate()
+ _assert_counts(2, 1, 1, 1)
+
+ _create_event(thread_id=thread_id)
+ _rotate()
+ _assert_counts(2, 1, 2, 1)
+
+ # Check that sending read receipts at different points results in the
+ # right counts.
+ _mark_read(event_id)
+ _assert_counts(1, 0, 2, 1)
+ _mark_read(event_id, thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 1, 0)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _mark_read(last_event_id)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+ _rotate()
+ _assert_counts(0, 0, 0, 0)
+
+ def test_count_aggregation_mixed(self) -> None:
+ """
+ This is essentially the same test as test_count_aggregation_threads, but
+ sends both unthreaded and threaded receipts.
+ """
+
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+ thread_id: str
+
+ last_event_id: str
+
+ def _assert_counts(
+ noitf_count: int,
+ highlight_count: int,
+ thread_notif_count: int,
+ thread_highlight_count: int,
+ ) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count,
+ unread_count=0,
+ highlight_count=highlight_count,
+ ),
+ )
+ if thread_notif_count or thread_highlight_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=thread_highlight_count,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ def _create_event(
+ highlight: bool = False, thread_id: Optional[str] = None
+ ) -> str:
+ content: JsonDict = {
+ "msgtype": "m.text",
+ "body": user_id if highlight else "msg",
+ }
+ if thread_id:
+ content["m.relates_to"] = {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ }
+
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content=content,
+ tok=other_token,
+ )
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
+
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
+
+ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ thread_id=thread_id,
+ data={},
+ )
+ )
+
+ _assert_counts(0, 0, 0, 0)
+ thread_id = _create_event()
+ _assert_counts(1, 0, 0, 0)
+ _rotate()
+ _assert_counts(1, 0, 0, 0)
+
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ _create_event()
+ _assert_counts(2, 0, 1, 0)
+ _rotate()
+ _assert_counts(2, 0, 1, 0)
+
+ event_id = _create_event(thread_id=thread_id)
+ _assert_counts(2, 0, 2, 0)
+ _rotate()
+ _assert_counts(2, 0, 2, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _mark_read(event_id)
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id, MAIN_TIMELINE)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event()
+ _create_event(thread_id=thread_id)
+ _assert_counts(1, 0, 1, 0)
+ _rotate()
+ _assert_counts(1, 0, 1, 0)
+
+ # Delete old event push actions, this should not affect the (summarised) count.
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ _assert_counts(1, 0, 1, 0)
+
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _assert_counts(1, 1, 0, 0)
+ _rotate()
+ _assert_counts(1, 1, 0, 0)
+
+ event_id = _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _rotate()
+ _assert_counts(1, 1, 1, 1)
+
+ # Check that adding another notification and rotating after highlight
+ # works.
+ _create_event()
+ _rotate()
+ _assert_counts(2, 1, 1, 1)
+
+ _create_event(thread_id=thread_id)
+ _rotate()
+ _assert_counts(2, 1, 2, 1)
+
+ # Check that sending read receipts at different points results in the
+ # right counts.
+ _mark_read(event_id)
+ _assert_counts(1, 0, 1, 0)
+ _mark_read(event_id, MAIN_TIMELINE)
+ _assert_counts(1, 0, 1, 0)
+ _mark_read(last_event_id, MAIN_TIMELINE)
+ _assert_counts(0, 0, 1, 0)
+ _mark_read(last_event_id, thread_id)
+ _assert_counts(0, 0, 0, 0)
+
+ _create_event(True)
+ _create_event(True, thread_id)
+ _assert_counts(1, 1, 1, 1)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0, 0, 0)
+ _rotate()
+ _assert_counts(0, 0, 0, 0)
+
+ def test_recursive_thread(self) -> None:
+ """
+ Events related to events in a thread should still be considered part of
+ that thread.
+ """
+
+ # Create a user to receive notifications and send receipts.
+ user_id = self.register_user("user1235", "pass")
+ token = self.login("user1235", "pass")
+
+ # And another users to send events.
+ other_id = self.register_user("other", "pass")
+ other_token = self.login("other", "pass")
+
+ # Create a room and put both users in it.
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.helper.join(room_id, other_id, tok=other_token)
+
+ # Update the user's push rules to care about reaction events.
+ self.get_success(
+ self.store.add_push_rule(
+ user_id,
+ "related_events",
+ priority_class=5,
+ conditions=[
+ {"kind": "event_match", "key": "type", "pattern": "m.reaction"}
+ ],
+ actions=["notify"],
+ )
+ )
+
+ def _create_event(type: str, content: JsonDict) -> str:
+ result = self.helper.send_event(
+ room_id, type=type, content=content, tok=other_token
+ )
+ return result["event_id"]
+
+ def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
+ counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ counts.main_timeline,
+ NotifCounts(
+ notify_count=noitf_count, unread_count=0, highlight_count=0
+ ),
+ )
+ if thread_notif_count:
+ self.assertEqual(
+ counts.threads,
+ {
+ thread_id: NotifCounts(
+ notify_count=thread_notif_count,
+ unread_count=0,
+ highlight_count=0,
+ ),
+ },
+ )
+ else:
+ self.assertEqual(counts.threads, {})
+
+ # Create a root event.
+ thread_id = _create_event(
+ "m.room.message", {"msgtype": "m.text", "body": "msg"}
+ )
+ _assert_counts(1, 0)
+
+ # Reply, creating a thread.
+ reply_id = _create_event(
+ "m.room.message",
+ {
+ "msgtype": "m.text",
+ "body": "msg",
+ "m.relates_to": {
+ "rel_type": "m.thread",
+ "event_id": thread_id,
+ },
+ },
+ )
+ _assert_counts(1, 1)
+
+ # Create an event related to a thread event, this should still appear in
+ # the thread.
+ _create_event(
+ type="m.reaction",
+ content={
+ "m.relates_to": {
+ "rel_type": "m.annotation",
+ "event_id": reply_id,
+ "key": "A",
+ }
+ },
+ )
+ _assert_counts(1, 2)
+
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
self.get_success(
diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
new file mode 100644
index 0000000000..cd1d00208b
--- /dev/null
+++ b/tests/storage/test_relations.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import MAIN_TIMELINE
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class RelationsStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ """
+ Creates a DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ F <--[m.annotation]-- G
+
+ """
+ self._main_store = self.hs.get_datastores().main
+
+ self._create_relation("A", "B", "m.thread")
+ self._create_relation("B", "C", "m.annotation")
+ self._create_relation("A", "D", "m.reference")
+ self._create_relation("D", "E", "m.annotation")
+ self._create_relation("F", "G", "m.annotation")
+
+ def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None:
+ self.get_success(
+ self._main_store.db_pool.simple_insert(
+ table="event_relations",
+ values={
+ "event_id": event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ },
+ )
+ )
+
+ def test_get_thread_id(self) -> None:
+ """
+ Ensure that get_thread_id only searches up the tree for threads.
+ """
+ # The thread itself and children of it return the thread.
+ thread_id = self.get_success(self._main_store.get_thread_id("B"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("C"))
+ self.assertEqual("A", thread_id)
+
+ # But the root and events related to the root do not.
+ thread_id = self.get_success(self._main_store.get_thread_id("A"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("D"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("E"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ # Events which are not related to a thread at all should return the
+ # main timeline.
+ thread_id = self.get_success(self._main_store.get_thread_id("F"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("G"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ def test_get_thread_id_for_receipts(self) -> None:
+ """
+ Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
+ """
+ # All of the events are considered related to this thread.
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
+ self.assertEqual("A", thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
+ self.assertEqual("A", thread_id)
+
+ # Events which are not related to a thread at all should return the
+ # main timeline.
+ thread_id = self.get_success(self._main_store.get_thread_id("F"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
+
+ thread_id = self.get_success(self._main_store.get_thread_id("G"))
+ self.assertEqual(MAIN_TIMELINE, thread_id)
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 78663a53fe..34fa810cf6 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,6 @@ from typing import List
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
-from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.types import JsonDict
@@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
- config["experimental_features"] = {"msc3440_enabled": True}
+ config["experimental_features"] = {"msc3874_enabled": True}
return config
def prepare(self, reactor, clock, homeserver):
@@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase):
self.third_tok = self.login("third", "test")
self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
+ # Store a token which is after all the room creation events.
+ self.from_token = self.get_success(
+ self.hs.get_event_sources().get_current_token_for_pagination(self.room_id)
+ )
+
# An initial event with a relation from second user.
res = self.helper.send_event(
room_id=self.room_id,
@@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase):
tok=self.tok,
)
self.event_id_1 = res["event_id"]
- self.helper.send_event(
+ res = self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
@@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase):
},
tok=self.second_tok,
)
+ self.event_id_annotation = res["event_id"]
# Another event with a relation from third user.
res = self.helper.send_event(
@@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase):
tok=self.tok,
)
self.event_id_2 = res["event_id"]
- self.helper.send_event(
+ res = self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
@@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase):
},
tok=self.third_tok,
)
+ self.event_id_reference = res["event_id"]
# An event with no relations.
- self.helper.send_event(
+ res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "No relations"},
tok=self.tok,
)
+ self.event_id_none = res["event_id"]
- def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
+ def _filter_messages(self, filter: JsonDict) -> List[str]:
"""Make a request to /messages with a filter, returns the chunk of events."""
- from_token = self.get_success(
- self.hs.get_event_sources().get_current_token_for_pagination(self.room_id)
- )
-
events, next_key = self.get_success(
self.hs.get_datastores().main.paginate_room_events(
room_id=self.room_id,
- from_key=from_token.room_key,
+ from_key=self.from_token.room_key,
to_key=None,
- direction="b",
+ direction="f",
limit=10,
event_filter=Filter(self.hs, filter),
)
)
- return events
+ return [ev.event_id for ev in events]
def test_filter_relation_senders(self):
# Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
# Messages which third user reacted to.
filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_2)
+ self.assertEqual(chunk, [self.event_id_2])
# Messages which either user reacted to.
filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
- )
+ self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
def test_filter_relation_type(self):
# Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
# Messages which have references.
filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_2)
+ self.assertEqual(chunk, [self.event_id_2])
# Messages which have either annotations or references.
filter = {
@@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase):
]
}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 2, chunk)
- self.assertCountEqual(
- [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
- )
+ self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
def test_filter_relation_senders_and_type(self):
# Messages which second user reacted to.
@@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase):
"related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
def test_duplicate_relation(self):
"""An event should only be returned once if there are multiple relations to it."""
@@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase):
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
- self.assertEqual(len(chunk), 1, chunk)
- self.assertEqual(chunk[0].event_id, self.event_id_1)
+ self.assertEqual(chunk, [self.event_id_1])
+
+ def test_filter_rel_types(self) -> None:
+ # Messages which are annotations.
+ filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(chunk, [self.event_id_annotation])
+
+ # Messages which are references.
+ filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(chunk, [self.event_id_reference])
+
+ # Messages which are either annotations or references.
+ filter = {
+ "org.matrix.msc3874.rel_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertCountEqual(
+ chunk,
+ [self.event_id_annotation, self.event_id_reference],
+ )
+
+ def test_filter_not_rel_types(self) -> None:
+ # Messages which are not annotations.
+ filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(
+ chunk,
+ [
+ self.event_id_1,
+ self.event_id_2,
+ self.event_id_reference,
+ self.event_id_none,
+ ],
+ )
+
+ # Messages which are not references.
+ filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]}
+ chunk = self._filter_messages(filter)
+ self.assertEqual(
+ chunk,
+ [
+ self.event_id_1,
+ self.event_id_annotation,
+ self.event_id_2,
+ self.event_id_none,
+ ],
+ )
+
+ # Messages which are neither annotations or references.
+ filter = {
+ "org.matrix.msc3874.not_rel_types": [
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ ]
+ }
+ chunk = self._filter_messages(filter)
+ self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none])
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 779fad1f63..80e5c590d8 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -86,8 +86,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
- self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
- pdus
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
+ lambda dest, pdus, **k: succeed(pdus)
)
# Send the join, it should return None (which is not an error)
diff --git a/tests/test_server.py b/tests/test_server.py
index 7c66448245..2d9a0257d4 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -222,13 +222,22 @@ class OptionsResourceTests(unittest.TestCase):
self.resource = OptionsResource()
self.resource.putChild(b"res", DummyResource())
- def _make_request(self, method: bytes, path: bytes) -> FakeChannel:
+ def _make_request(
+ self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False
+ ) -> FakeChannel:
"""Create a request from the method/path and return a channel with the response."""
# Create a site and query for the resource.
site = SynapseSite(
"test",
"site_tag",
- parse_listener_def(0, {"type": "http", "port": 0}),
+ parse_listener_def(
+ 0,
+ {
+ "type": "http",
+ "port": 0,
+ "experimental_cors_msc3886": experimental_cors_msc3886,
+ },
+ ),
self.resource,
"1.0",
max_request_body_size=4096,
@@ -239,25 +248,58 @@ class OptionsResourceTests(unittest.TestCase):
channel = make_request(self.reactor, site, method, path, shorthand=False)
return channel
+ def _check_cors_standard_headers(self, channel: FakeChannel) -> None:
+ # Ensure the correct CORS headers have been added
+ # as per https://spec.matrix.org/v1.4/client-server-api/#web-browser-clients
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
+ [b"*"],
+ "has correct CORS Origin header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"),
+ [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec
+ "has correct CORS Methods header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"),
+ [b"X-Requested-With, Content-Type, Authorization, Date"],
+ "has correct CORS Headers header",
+ )
+
+ def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None:
+ # Ensure the correct CORS headers have been added
+ # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
+ [b"*"],
+ "has correct CORS Origin header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"),
+ [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec
+ "has correct CORS Methods header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"),
+ [
+ b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match"
+ ],
+ "has correct CORS Headers header",
+ )
+ self.assertEqual(
+ channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
+ [b"ETag, Location, X-Max-Bytes"],
+ "has correct CORS Expose Headers header",
+ )
+
def test_unknown_options_request(self) -> None:
"""An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
- # Ensure the correct CORS headers have been added
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
- "has CORS Origin header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
- "has CORS Methods header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
- "has CORS Headers header",
- )
+ self._check_cors_standard_headers(channel)
def test_known_options_request(self) -> None:
"""An OPTIONS requests to an known URL still returns 204 No Content."""
@@ -265,19 +307,17 @@ class OptionsResourceTests(unittest.TestCase):
self.assertEqual(channel.code, 204)
self.assertNotIn("body", channel.result)
- # Ensure the correct CORS headers have been added
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
- "has CORS Origin header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
- "has CORS Methods header",
- )
- self.assertTrue(
- channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
- "has CORS Headers header",
+ self._check_cors_standard_headers(channel)
+
+ def test_known_options_request_msc3886(self) -> None:
+ """An OPTIONS requests to an known URL still returns 204 No Content."""
+ channel = self._make_request(
+ b"OPTIONS", b"/res/", experimental_cors_msc3886=True
)
+ self.assertEqual(channel.code, 204)
+ self.assertNotIn("body", channel.result)
+
+ self._check_cors_msc3886_headers(channel)
def test_unknown_request(self) -> None:
"""A non-OPTIONS request to an unknown URL should 404."""
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 90861fe522..78fd7b6961 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -1037,5 +1037,5 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj = Cls()
# Make sure this raises an error about the arg mismatch
- with self.assertRaises(Exception):
+ with self.assertRaises(TypeError):
obj.list_fn([("foo", "bar")])
|