From 2cc5ea933dbe65445e3711bb3f05022b007029ea Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 24 Feb 2022 17:55:45 +0000 Subject: Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services. (#11617) Co-authored-by: Erik Johnston --- tests/appservice/test_scheduler.py | 55 ++++++++--- tests/handlers/test_appservice.py | 194 ++++++++++++++++++++++++++++++++++++- tests/storage/test_appservice.py | 8 +- 3 files changed, 233 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 8fb6687f89..b9dc4dfe1b 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -92,6 +94,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events=events, ephemeral=[], to_device_messages=[], # txn made and saved + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -114,7 +118,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[], to_device_messages=[] + service=service, + events=events, + ephemeral=[], + to_device_messages=[], + one_time_key_counts={}, + unused_fallback_keys={}, ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -231,11 +240,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], []) + self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event2, event3], [], [], None, None + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -261,15 +272,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -288,13 +299,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) + self.txn_ctrl.send.assert_called_with( + service, [event_list[0]], [], [], None, None + ) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[1:101], [], [], None, None + ) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) + self.txn_ctrl.send.assert_called_with( + service, event_list[101:], [], [], None, None + ) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): @@ -302,14 +319,18 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], event_list, [], None, None + ) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() @@ -324,13 +345,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [] + service, [], event_list_2 + event_list_3, [], None, None ) self.assertEquals(2, self.txn_ctrl.send.call_count) @@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) + self.txn_ctrl.send.assert_called_once_with( + service, [], first_chunk, [], None, None + ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9918ff6807..6e0ec37963 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.storage -from synapse.appservice import ApplicationService +from synapse.appservice import ( + ApplicationService, + TransactionOneTimeKeyCounts, + TransactionUnusedFallbackKeys, +) from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.rest.client import login, receipts, room, sendtodevice +from synapse.rest.client import login, receipts, register, room, sendtodevice +from synapse.server import HomeServer from synapse.types import RoomStreamToken +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import make_awaitable, simple_async_mock +from tests.unittest import override_config from tests.utils import MockClock @@ -428,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # # The uninterested application service should not have been notified at all. self.send_mock.assert_called_once() - service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent self.assertEqual(service, interested_appservice) @@ -540,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages = call[0] + service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -582,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self._services.append(appservice) return appservice + + +class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): + # Argument indices for pulling out arguments from a `send_mock`. + ARG_OTK_COUNTS = 4 + ARG_FALLBACK_KEYS = 5 + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track what's going out + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + + # Define an application service for the tests + self._service_token = "VERYSECRET" + self._service = ApplicationService( + self._service_token, + "as1.invalid", + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastores().main.services_cache = [self._service] + + # Register some appservice users + self._sender_user, self._sender_device = self.register_appservice_user( + "as.sender", self._service_token + ) + self._namespaced_user, self._namespaced_device = self.register_appservice_user( + "_as_user1", self._service_token + ) + + # Register a real user as well. + self._real_user = self.register_user("real.user", "meow") + self._real_user_token = self.login("real.user", "meow") + + async def _add_otks_for_device( + self, user_id: str, device_id: str, otk_count: int + ) -> None: + """ + Add some dummy keys. It doesn't matter if they're not a real algorithm; + that should be opaque to the server anyway. + """ + await self.hs.get_datastores().main.add_e2e_one_time_keys( + user_id, + device_id, + self.clock.time_msec(), + [("algo", f"k{i}", "{}") for i in range(otk_count)], + ) + + async def _add_fallback_key_for_device( + self, user_id: str, device_id: str, used: bool + ) -> None: + """ + Adds a fake fallback key to a device, optionally marking it as used + right away. + """ + store = self.hs.get_datastores().main + await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"}) + if used is True: + # Mark the key as used + await store.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": "algo", + "key_id": "fk", + }, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", + ) + + def _set_up_devices_and_a_room(self) -> str: + """ + Helper to set up devices for all the users + and a room for the users to talk in. + """ + + async def preparation(): + await self._add_otks_for_device(self._sender_user, self._sender_device, 42) + await self._add_fallback_key_for_device( + self._sender_user, self._sender_device, used=True + ) + await self._add_otks_for_device( + self._namespaced_user, self._namespaced_device, 36 + ) + await self._add_fallback_key_for_device( + self._namespaced_user, self._namespaced_device, used=False + ) + + # Register a device for the real user, too, so that we can later ensure + # that we don't leak information to the AS about the non-AS user. + await self.hs.get_datastores().main.store_device( + self._real_user, "REALDEV", "UltraMatrix 3000" + ) + await self._add_otks_for_device(self._real_user, "REALDEV", 50) + + self.get_success(preparation()) + + room_id = self.helper.create_room_as( + self._real_user, is_public=True, tok=self._real_user_token + ) + self.helper.join( + room_id, + self._namespaced_user, + tok=self._service_token, + appservice_user_id=self._namespaced_user, + ) + + # Check it was called for sanity. (This was to send the join event to the AS.) + self.send_mock.assert_called() + self.send_mock.reset_mock() + + return room_id + + @override_config( + {"experimental_features": {"msc3202_transaction_extensions": True}} + ) + def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus( + self, + ) -> None: + """ + Tests that: + - the AS receives one-time key counts and unused fallback keys for: + - the specified sender; and + - any user who is in receipt of the PDUs + """ + + room_id = self._set_up_devices_and_a_room() + + # Send a message into the AS's room + self.helper.send(room_id, "woof woof", tok=self._real_user_token) + + # Capture what was sent as an AS transaction. + self.send_mock.assert_called() + last_args, _last_kwargs = self.send_mock.call_args + otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS] + unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[ + self.ARG_FALLBACK_KEYS + ] + + self.assertEqual( + otks, + { + "@as.sender:test": {self._sender_device: {"algo": 42}}, + "@_as_user1:test": {self._namespaced_device: {"algo": 36}}, + }, + ) + self.assertEqual( + unused_fallbacks, + { + "@as.sender:test": {self._sender_device: []}, + "@_as_user1:test": {self._namespaced_device: ["algo"]}, + }, + ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 50703ccaee..d2f654214e 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -267,7 +267,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) ) self.assertEquals(txn.id, 1) @@ -283,7 +283,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -296,7 +296,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -320,7 +320,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) txn = self.get_success( - self.store.create_appservice_txn(service, events, [], []) + self.store.create_appservice_txn(service, events, [], [], {}, {}) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) -- cgit 1.4.1