diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index a5aa500ef8..f1e357764f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
- "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+ "room_id",
+ "m.read",
+ "user_id",
+ ["event_id"],
+ thread_id=None,
+ data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
- "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+ "room_id",
+ "m.read",
+ "user_id",
+ ["event_id"],
+ thread_id=None,
+ data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
# send the second RR
receipt = ReadReceipt(
- "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
+ "room_id",
+ "m.read",
+ "user_id",
+ ["other_id"],
+ thread_id=None,
+ data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index b17af2725b..af24c4984d 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipt_type="m.read",
user_id=self.local_user,
event_ids=[f"$eventid_{i}"],
+ thread_id=None,
data={},
)
)
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 7b9b711521..bce65fab7d 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -15,11 +15,11 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes
-from synapse.push.baserules import PushRule
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
from synapse.rest.client import account, login
from synapse.server import HomeServer
+from synapse.synapse_rust.push import PushRule
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -161,20 +161,15 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self._store.get_push_rules_for_user(self.user)
)
# Filter out default rules; we don't care
- push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
+ push_rules = [
+ r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+ ]
# Check our rule made it
- self.assertEqual(
- push_rules,
- [
- PushRule(
- rule_id="personal.override.rule1",
- priority_class=5,
- conditions=[],
- actions=[],
- )
- ],
- push_rules,
- )
+ self.assertEqual(len(push_rules), 1)
+ self.assertEqual(push_rules[0].rule_id, "personal.override.rule1")
+ self.assertEqual(push_rules[0].priority_class, 5)
+ self.assertEqual(push_rules[0].conditions, [])
+ self.assertEqual(push_rules[0].actions, [])
# Request the deactivation of our account
self._deactivate_my_account()
@@ -183,7 +178,9 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self._store.get_push_rules_for_user(self.user)
)
# Filter out default rules; we don't care
- push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
+ push_rules = [
+ r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+ ]
# Check our rule no longer exists
self.assertEqual(push_rules, [], push_rules)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 8adba29d7f..9c821b3042 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -129,7 +129,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
- hs.get_event_auth_handler().check_host_in_room = check_host_in_room
+ hs.get_event_auth_handler().is_host_in_room = check_host_in_room
async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}
@@ -138,6 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
get_current_hosts_in_room
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+ get_current_hosts_in_room
+ )
+
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 7a3b0d6755..fd14568f55 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.pusher = self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
"""
with self.assertRaises(SynapseError) as cm:
self.get_success_or_raise(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index d9c68cdd2d..b383b8401f 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -19,9 +19,10 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
-from synapse.push import PusherConfigException
-from synapse.rest.client import login, push_rule, receipts, room
+from synapse.push import PusherConfig, PusherConfigException
+from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -35,6 +36,7 @@ class HTTPPusherTests(HomeserverTestCase):
login.register_servlets,
receipts.register_servlets,
push_rule.register_servlets,
+ pusher.register_servlets,
]
user_id = True
hijack_auth = False
@@ -74,7 +76,7 @@ class HTTPPusherTests(HomeserverTestCase):
def test_data(data: Optional[JsonDict]) -> None:
self.get_failure(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -119,7 +121,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -235,7 +237,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -355,7 +357,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -441,7 +443,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -518,7 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -624,7 +626,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -728,18 +730,38 @@ class HTTPPusherTests(HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.json_body)
- def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+ def _make_user_with_pusher(
+ self, username: str, enabled: bool = True
+ ) -> Tuple[str, str]:
+ """Registers a user and creates a pusher for them.
+
+ Args:
+ username: the localpart of the new user's Matrix ID.
+ enabled: whether to create the pusher in an enabled or disabled state.
+ """
user_id = self.register_user(username, "pass")
access_token = self.login(username, "pass")
# Register the pusher
+ self._set_pusher(user_id, access_token, enabled)
+
+ return user_id, access_token
+
+ def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
+ """Creates or updates the pusher for the given user.
+
+ Args:
+ user_id: the user's Matrix ID.
+ access_token: the access token associated with the pusher.
+ enabled: whether to enable or disable the pusher.
+ """
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -749,11 +771,11 @@ class HTTPPusherTests(HomeserverTestCase):
pushkey="a@example.com",
lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"},
+ enabled=enabled,
+ device_id=user_tuple.device_id,
)
)
- return user_id, access_token
-
def test_dont_notify_rule_overrides_message(self) -> None:
"""
The override push rule will suppress notification
@@ -791,3 +813,148 @@ class HTTPPusherTests(HomeserverTestCase):
# The user sends a message back (sends a notification)
self.helper.send(room, body="Hello", tok=access_token)
self.assertEqual(len(self.push_attempts), 1)
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_disable(self) -> None:
+ """Tests that disabling a pusher means it's not pushed to anymore."""
+ user_id, access_token = self._make_user_with_pusher("user")
+ other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+ room = self.helper.create_room_as(user_id, tok=access_token)
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Send a message and check that it generated a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Disable the pusher.
+ self._set_pusher(user_id, access_token, enabled=False)
+
+ # Send another message and check that it did not generate a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Get the pushers for the user and check that it is marked as disabled.
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+ enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+ self.assertFalse(enabled)
+ self.assertTrue(isinstance(enabled, bool))
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_enable(self) -> None:
+ """Tests that enabling a disabled pusher means it gets pushed to."""
+ # Create the user with the pusher already disabled.
+ user_id, access_token = self._make_user_with_pusher("user", enabled=False)
+ other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+ room = self.helper.create_room_as(user_id, tok=access_token)
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Send a message and check that it did not generate a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 0)
+
+ # Enable the pusher.
+ self._set_pusher(user_id, access_token, enabled=True)
+
+ # Send another message and check that it did generate a push.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ # Get the pushers for the user and check that it is marked as enabled.
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+ enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+ self.assertTrue(enabled)
+ self.assertTrue(isinstance(enabled, bool))
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_null_enabled(self) -> None:
+ """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
+ created before the column was introduced) is considered enabled.
+ """
+ # We intentionally set 'enabled' to None so that it's stored as NULL in the
+ # database.
+ user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
+
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+ self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
+
+ def test_update_different_device_access_token_device_id(self) -> None:
+ """Tests that if we create a pusher from one device, the update it from another
+ device, the access token and device ID associated with the pusher stays the
+ same.
+ """
+ # Create a user with a pusher.
+ user_id, access_token = self._make_user_with_pusher("user")
+
+ # Get the token ID for the current access token, since that's what we store in
+ # the pushers table. Also get the device ID from it.
+ user_tuple = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+ device_id = user_tuple.device_id
+
+ # Generate a new access token, and update the pusher with it.
+ new_token = self.login("user", "pass")
+ self._set_pusher(user_id, new_token, enabled=False)
+
+ # Get the current list of pushers for the user.
+ ret = self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
+ pushers: List[PusherConfig] = list(ret)
+
+ # Check that we still have one pusher, and that the access token and device ID
+ # associated with it didn't change.
+ self.assertEqual(len(pushers), 1)
+ self.assertEqual(pushers[0].access_token, token_id)
+ self.assertEqual(pushers[0].device_id, device_id)
+
+ @override_config({"experimental_features": {"msc3881_enabled": True}})
+ def test_device_id(self) -> None:
+ """Tests that a pusher created with a given device ID shows that device ID in
+ GET /pushers requests.
+ """
+ self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # We create the pusher with an HTTP request rather than with
+ # _make_user_with_pusher so that we can test the device ID is correctly set when
+ # creating a pusher via an API call.
+ self.make_request(
+ method="POST",
+ path="/pushers/set",
+ content={
+ "kind": "http",
+ "app_id": "m.http",
+ "app_display_name": "HTTP Push Notifications",
+ "device_display_name": "pushy push",
+ "pushkey": "a@example.com",
+ "lang": "en",
+ "data": {"url": "http://example.com/_matrix/push/v1/notify"},
+ },
+ access_token=access_token,
+ )
+
+ # Look up the user info for the access token so we can compare the device ID.
+ lookup_result: TokenLookupResult = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+
+ # Get the user's devices and check it has the correct device ID.
+ channel = self.make_request("GET", "/pushers", access_token=access_token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["pushers"]), 1)
+ self.assertEqual(
+ channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
+ lookup_result.device_id,
+ )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 49a21e2e85..efd92793c0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -171,7 +171,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if send_receipt:
self.get_success(
self.master_store.insert_receipt(
- ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
+ ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
)
)
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index eb00117845..ede6d0c118 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastores().main.insert_receipt(
- "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
+ "!room:blue",
+ "m.read",
+ USER_ID,
+ ["$event:blue"],
+ thread_id=None,
+ data={"a": 1},
)
)
self.replicate()
@@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual("$event:blue", row.event_id)
+ self.assertIsNone(row.thread_id)
self.assertEqual({"a": 1}, row.data)
# Now let's disconnect and insert some data.
@@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.get_success(
self.hs.get_datastores().main.insert_receipt(
- "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ "!room2:blue",
+ "m.read",
+ USER_ID,
+ ["$event2:foo"],
+ thread_id=None,
+ data={"a": 2},
)
)
self.replicate()
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
new file mode 100644
index 0000000000..b93cae67d3
--- /dev/null
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.module_api import cached
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+FIRST_VALUE = "one"
+SECOND_VALUE = "two"
+
+KEY = "mykey"
+
+
+class TestCache:
+ current_value = FIRST_VALUE
+
+ @cached()
+ async def cached_function(self, user_id: str) -> str:
+ return self.current_value
+
+
+class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ ]
+
+ def test_module_cache_full_invalidation(self):
+ main_cache = TestCache()
+ self.hs.get_module_api().register_cached_function(main_cache.cached_function)
+
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
+
+ worker_cache = TestCache()
+ worker_hs.get_module_api().register_cached_function(
+ worker_cache.cached_function
+ )
+
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(
+ FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+ )
+
+ main_cache.current_value = SECOND_VALUE
+ worker_cache.current_value = SECOND_VALUE
+ # No invalidation yet, should return the cached value on both the main process and the worker
+ self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
+ self.assertEqual(
+ FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
+ )
+
+ # Full invalidation on the main process, should be replicated on the worker that
+ # should returned the updated value too
+ self.get_success(
+ self.hs.get_module_api().invalidate_cache(
+ main_cache.cached_function, (KEY,)
+ )
+ )
+
+ self.assertEqual(
+ SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
+ )
+ self.assertEqual(
+ SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
+ )
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 8f4f6688ce..59fea93e49 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
token_id = user_dict.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9f536ceeb3..1847e6ad6b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.other_user,
access_token=token_id,
kind="http",
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
new file mode 100644
index 0000000000..d5bb16c98d
--- /dev/null
+++ b/tests/rest/client/test_login_token_request.py
@@ -0,0 +1,132 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, login_token_request
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ login_token_request.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.registration.enable_registration = True
+ self.hs.config.registration.registrations_require_3pid = []
+ self.hs.config.registration.auto_join_rooms = []
+ self.hs.config.captcha.enable_registration_captcha = False
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = "user123"
+ self.password = "password"
+
+ def test_disabled(self) -> None:
+ channel = self.make_request("POST", "/login/token", {}, access_token=None)
+ self.assertEqual(channel.code, 400)
+
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", "/login/token", {}, access_token=token)
+ self.assertEqual(channel.code, 400)
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_require_auth(self) -> None:
+ channel = self.make_request("POST", "/login/token", {}, access_token=None)
+ self.assertEqual(channel.code, 401)
+
+ @override_config({"experimental_features": {"msc3882_enabled": True}})
+ def test_uia_on(self) -> None:
+ user_id = self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", "/login/token", {}, access_token=token)
+ self.assertEqual(channel.code, 401)
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+ session = channel.json_body["session"]
+
+ uia = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": self.user},
+ "password": self.password,
+ "session": session,
+ },
+ }
+
+ channel = self.make_request("POST", "/login/token", uia, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 300)
+
+ login_token = channel.json_body["login_token"]
+
+ channel = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ @override_config(
+ {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+ )
+ def test_uia_off(self) -> None:
+ user_id = self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", "/login/token", {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 300)
+
+ login_token = channel.json_body["login_token"]
+
+ channel = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3882_enabled": True,
+ "msc3882_ui_auth": False,
+ "msc3882_token_timeout": "15s",
+ }
+ }
+ )
+ def test_expires_in(self) -> None:
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ channel = self.make_request("POST", "/login/token", {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 651f4f415d..d33e34d829 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel.json_body["chunk"][0],
)
+ @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_repeated_paginate_relations(self) -> None:
"""Test that if we paginate using a limit and tokens then we get the
expected events.
@@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
+ # Test forward pagination.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEqual(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEqual(found_event_ids, expected_event_ids)
+
def test_pagination_from_sync_and_messages(self) -> None:
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index a6679e1312..85739c464e 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -12,25 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple, Union
+import datetime
+from typing import Dict, List, Tuple, Union
import attr
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.events import _EventInternalMetadata
-from synapse.util import json_encoder
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict
+from synapse.util import Clock, json_encoder
import tests.unittest
import tests.utils
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class _BackfillSetupInfo:
+ room_id: str
+ depth_map: Dict[str, int]
+
+
class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
def test_get_prev_events_for_room(self):
@@ -571,11 +584,471 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(count, 1)
- _, event_id = self.get_success(
+ next_staged_event_info = self.get_success(
self.store.get_next_staged_event_id_for_room(room_id)
)
+ assert next_staged_event_info
+ _, event_id = next_staged_event_info
self.assertEqual(event_id, "$fake_event_id_500")
+ def _setup_room_for_backfill_tests(self) -> _BackfillSetupInfo:
+ """
+ Sets up a room with various events and backward extremities to test
+ backfill functions against.
+
+ Returns:
+ _BackfillSetupInfo including the `room_id` to test against and
+ `depth_map` of events in the room
+ """
+ room_id = "!backfill-room-test:some-host"
+
+ # The silly graph we use to test grabbing backward extremities,
+ # where the top is the oldest events.
+ # 1 (oldest)
+ # |
+ # 2 ⹁
+ # | \
+ # | [b1, b2, b3]
+ # | |
+ # | A
+ # | /
+ # 3 {
+ # | \
+ # | [b4, b5, b6]
+ # | |
+ # | B
+ # | /
+ # 4 ´
+ # |
+ # 5 (newest)
+
+ event_graph: Dict[str, List[str]] = {
+ "1": [],
+ "2": ["1"],
+ "3": ["2", "A"],
+ "4": ["3", "B"],
+ "5": ["4"],
+ "A": ["b1", "b2", "b3"],
+ "b1": ["2"],
+ "b2": ["2"],
+ "b3": ["2"],
+ "B": ["b4", "b5", "b6"],
+ "b4": ["3"],
+ "b5": ["3"],
+ "b6": ["3"],
+ }
+
+ depth_map: Dict[str, int] = {
+ "1": 1,
+ "2": 2,
+ "b1": 3,
+ "b2": 3,
+ "b3": 3,
+ "A": 4,
+ "3": 5,
+ "b4": 6,
+ "b5": 6,
+ "b6": 6,
+ "B": 7,
+ "4": 8,
+ "5": 9,
+ }
+
+ # The events we have persisted on our server.
+ # The rest are events in the room but not backfilled tet.
+ our_server_events = {"5", "4", "B", "3", "A"}
+
+ complete_event_dict_map: Dict[str, JsonDict] = {}
+ stream_ordering = 0
+ for (event_id, prev_event_ids) in event_graph.items():
+ depth = depth_map[event_id]
+
+ complete_event_dict_map[event_id] = {
+ "event_id": event_id,
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": "@sender",
+ "prev_event_ids": prev_event_ids,
+ "auth_event_ids": [],
+ "origin_server_ts": stream_ordering,
+ "depth": depth,
+ "stream_ordering": stream_ordering,
+ "content": {"body": "event" + event_id},
+ }
+
+ stream_ordering += 1
+
+ def populate_db(txn: LoggingTransaction):
+ # Insert the room to satisfy the foreign key constraint of
+ # `event_failed_pull_attempts`
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ },
+ )
+
+ # Insert our server events
+ for event_id in our_server_events:
+ event_dict = complete_event_dict_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_dict.get("event_id"),
+ "type": event_dict.get("type"),
+ "room_id": event_dict.get("room_id"),
+ "depth": event_dict.get("depth"),
+ "topological_ordering": event_dict.get("depth"),
+ "stream_ordering": event_dict.get("stream_ordering"),
+ "processed": True,
+ "outlier": False,
+ },
+ )
+
+ # Insert the event edges
+ for event_id in our_server_events:
+ for prev_event_id in event_graph[event_id]:
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event_id,
+ "prev_event_id": prev_event_id,
+ "room_id": room_id,
+ },
+ )
+
+ # Insert the backward extremities
+ prev_events_of_our_events = {
+ prev_event_id
+ for our_server_event in our_server_events
+ for prev_event_id in complete_event_dict_map[our_server_event][
+ "prev_event_ids"
+ ]
+ }
+ backward_extremities = prev_events_of_our_events - our_server_events
+ for backward_extremity in backward_extremities:
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="event_backward_extremities",
+ values={
+ "event_id": backward_extremity,
+ "room_id": room_id,
+ },
+ )
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_setup_room_for_backfill_tests_populate_db",
+ populate_db,
+ )
+ )
+
+ return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
+
+ def test_get_backfill_points_in_room(self):
+ """
+ Test to make sure we get some backfill points
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertListEqual(
+ backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
+ )
+
+ def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
+ self,
+ ):
+ """
+ Test to make sure that events we have attempted to backfill (and within
+ backoff timeout duration) do not show up as an event to backfill again.
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+
+ # Record some attempts to backfill these events which will make
+ # `get_backfill_points_in_room` exclude them because we
+ # haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b5", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b4", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b2", "fake cause")
+ )
+
+ # No time has passed since we attempted to backfill ^
+
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ # Only the backfill points that we didn't record earlier exist here.
+ self.assertListEqual(backfill_event_ids, ["b6", "2", "b1"])
+
+ def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
+ self,
+ ):
+ """
+ Test to make sure after we fake attempt to backfill event "b3" many times,
+ we can see retry and see the "b3" again after the backoff timeout duration
+ has exceeded.
+ """
+ setup_info = self._setup_room_for_backfill_tests()
+ room_id = setup_info.room_id
+
+ # Record some attempts to backfill these events which will make
+ # `get_backfill_points_in_room` exclude them because we
+ # haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b3", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(room_id, "b1", "fake cause")
+ )
+
+ # Now advance time by 2 hours and we should only be able to see "b3"
+ # because we have waited long enough for the single attempt (2^1 hours)
+ # but we still shouldn't see "b1" because we haven't waited long enough
+ # for this many attempts. We didn't do anything to "b2" so it should be
+ # visible regardless.
+ self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+ # Make sure that "b1" is not in the list because we've
+ # already attempted many times
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertListEqual(backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2"])
+
+ # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
+ # see if we can now backfill it
+ self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
+
+ # Try again after we advanced enough time and we should see "b3" again
+ backfill_points = self.get_success(
+ self.store.get_backfill_points_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertListEqual(
+ backfill_event_ids, ["b6", "b5", "b4", "2", "b3", "b2", "b1"]
+ )
+
+ def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
+ """
+ Sets up a room with various insertion event backward extremities to test
+ backfill functions against.
+
+ Returns:
+ _BackfillSetupInfo including the `room_id` to test against and
+ `depth_map` of events in the room
+ """
+ room_id = "!backfill-room-test:some-host"
+
+ depth_map: Dict[str, int] = {
+ "1": 1,
+ "2": 2,
+ "insertion_eventA": 3,
+ "3": 4,
+ "insertion_eventB": 5,
+ "4": 6,
+ "5": 7,
+ }
+
+ def populate_db(txn: LoggingTransaction):
+ # Insert the room to satisfy the foreign key constraint of
+ # `event_failed_pull_attempts`
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ },
+ )
+
+ # Insert our server events
+ stream_ordering = 0
+ for event_id, depth in depth_map.items():
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "type": EventTypes.MSC2716_INSERTION
+ if event_id.startswith("insertion_event")
+ else "test_regular_type",
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "stream_ordering": stream_ordering,
+ "processed": True,
+ "outlier": False,
+ },
+ )
+
+ if event_id.startswith("insertion_event"):
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="insertion_event_extremities",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ },
+ )
+
+ stream_ordering += 1
+
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_setup_room_for_insertion_backfill_tests_populate_db",
+ populate_db,
+ )
+ )
+
+ return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
+
+ def test_get_insertion_event_backward_extremities_in_room(self):
+ """
+ Test to make sure insertion event backward extremities are returned.
+ """
+ setup_info = self._setup_room_for_insertion_backfill_tests()
+ room_id = setup_info.room_id
+
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertListEqual(
+ backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
+ )
+
+ def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
+ self,
+ ):
+ """
+ Test to make sure that insertion events we have attempted to backfill
+ (and within backoff timeout duration) do not show up as an event to
+ backfill again.
+ """
+ setup_info = self._setup_room_for_insertion_backfill_tests()
+ room_id = setup_info.room_id
+
+ # Record some attempts to backfill these events which will make
+ # `get_insertion_event_backward_extremities_in_room` exclude them
+ # because we haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+
+ # No time has passed since we attempted to backfill ^
+
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ # Only the backfill points that we didn't record earlier exist here.
+ self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
+
+ def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
+ self,
+ ):
+ """
+ Test to make sure after we fake attempt to backfill event
+ "insertion_eventA" many times, we can see retry and see the
+ "insertion_eventA" again after the backoff timeout duration has
+ exceeded.
+ """
+ setup_info = self._setup_room_for_insertion_backfill_tests()
+ room_id = setup_info.room_id
+
+ # Record some attempts to backfill these events which will make
+ # `get_backfill_points_in_room` exclude them because we
+ # haven't passed the backoff interval.
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventB", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+ self.get_success(
+ self.store.record_event_failed_pull_attempt(
+ room_id, "insertion_eventA", "fake cause"
+ )
+ )
+
+ # Now advance time by 2 hours and we should only be able to see
+ # "insertion_eventB" because we have waited long enough for the single
+ # attempt (2^1 hours) but we still shouldn't see "insertion_eventA"
+ # because we haven't waited long enough for this many attempts.
+ self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
+
+ # Make sure that "insertion_eventA" is not in the list because we've
+ # already attempted many times
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertListEqual(backfill_event_ids, ["insertion_eventB"])
+
+ # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
+ # see if we can now backfill it
+ self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
+
+ # Try at "insertion_eventA" again after we advanced enough time and we
+ # should see "insertion_eventA" again
+ backfill_points = self.get_success(
+ self.store.get_insertion_event_backward_extremities_in_room(room_id)
+ )
+ backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
+ self.assertListEqual(
+ backfill_event_ids, ["insertion_eventB", "insertion_eventA"]
+ )
+
@attr.s
class FakeEvent:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index fc43d7edd1..473c965e19 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Tuple
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
@@ -22,8 +24,6 @@ from synapse.util import Clock
from tests.unittest import HomeserverTestCase
-USER_ID = "@user:example.com"
-
class EventPushActionsStoreTestCase(HomeserverTestCase):
servlets = [
@@ -38,21 +38,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None
self.persist_events_store = persist_events_store
- def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
- self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_http(
- USER_ID, 0, 1000, 20
- )
- )
+ def _create_users_and_room(self) -> Tuple[str, str, str, str, str]:
+ """
+ Creates two users and a shared room.
- def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
- self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_email(
- USER_ID, 0, 1000, 20
- )
- )
-
- def test_count_aggregation(self) -> None:
+ Returns:
+ Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID).
+ """
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
@@ -65,6 +57,70 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
+ return user_id, token, other_id, other_token, room_id
+
+ def test_get_unread_push_actions_for_user_in_range(self) -> None:
+ """Test getting unread push actions for HTTP and email pushers."""
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+
+ # Create two events, one of which is a highlight.
+ self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={"msgtype": "m.text", "body": "msg"},
+ tok=other_token,
+ )
+ event_id = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={"msgtype": "m.text", "body": user_id},
+ tok=other_token,
+ )["event_id"]
+
+ # Fetch unread actions for HTTP pushers.
+ http_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(2, len(http_actions))
+
+ # Fetch unread actions for email pushers.
+ email_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual(2, len(email_actions))
+
+ # Send a receipt, which should clear any actions.
+ self.get_success(
+ self.store.insert_receipt(
+ room_id,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ thread_id=None,
+ data={},
+ )
+ )
+ http_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_http(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual([], http_actions)
+ email_actions = self.get_success(
+ self.store.get_unread_push_actions_for_user_in_range_for_email(
+ user_id, 0, 1000, 20
+ )
+ )
+ self.assertEqual([], email_actions)
+
+ def test_count_aggregation(self) -> None:
+ # Create a user to receive notifications and send receipts.
+ user_id, token, _, other_token, room_id = self._create_users_and_room()
+
last_event_id: str
def _assert_counts(noitf_count: int, highlight_count: int) -> None:
@@ -106,6 +162,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
"m.read",
user_id=user_id,
event_ids=[event_id],
+ thread_id=None,
data={},
)
)
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index c89bfff241..9459ee1705 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -131,13 +131,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event1_2_id],
+ None,
+ {},
)
)
@@ -164,7 +169,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
@@ -180,7 +185,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+ self.room_id2,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event2_1_id],
+ None,
+ {},
)
)
res = self.get_success(
@@ -202,13 +212,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event1_2_id],
+ None,
+ {},
)
)
@@ -241,7 +256,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating
self.get_success(
self.store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
@@ -259,7 +274,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+ self.room_id2,
+ ReceiptTypes.READ_PRIVATE,
+ OUR_USER_ID,
+ [event2_1_id],
+ None,
+ {},
)
)
res = self.get_success(
diff --git a/tests/unittest.py b/tests/unittest.py
index 975b0a23a7..00cb023198 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
assert self.helper.auth_user_id is not None
+ token = "some_fake_token"
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastores().main.add_access_token_to_user(
self.helper.auth_user_id,
- "some_fake_token",
+ token,
None,
None,
)
)
- async def get_user_by_access_token(
- token: Optional[str] = None, allow_guest: bool = False
- ) -> JsonDict:
- assert self.helper.auth_user_id is not None
- return {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": token_id,
- "is_guest": False,
- }
-
- async def get_user_by_req(
- request: SynapseRequest,
- allow_guest: bool = False,
- allow_expired: bool = False,
- ) -> Requester:
+ # This has to be a function and not just a Mock, because
+ # `self.helper.auth_user_id` is temporarily reassigned in some tests
+ async def get_requester(*args, **kwargs) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
- UserID.from_string(self.helper.auth_user_id),
- token_id,
- False,
- False,
- None,
+ user_id=UserID.from_string(self.helper.auth_user_id),
+ access_token_id=token_id,
)
# Type ignore: mypy doesn't like us assigning to methods.
- self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
- self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
- self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
- return_value="1234"
- )
+ self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment]
+ self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment]
+ self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment]
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
|