From 034db2ba2115d935ce62b641b4051e477a454eac Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 26 Sep 2019 11:47:53 +0100 Subject: Fix dummy event insertion consent bug (#6053) Fixes #5905 --- synapse/handlers/message.py | 99 ++++++++++++++++++++++++++++++++------------- 1 file changed, 72 insertions(+), 27 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 1f8272784e..0f8cce8ffe 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -222,6 +222,13 @@ class MessageHandler(object): } +# The duration (in ms) after which rooms should be removed +# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try +# to generate a dummy event for them once more) +# +_DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000 + + class EventCreationHandler(object): def __init__(self, hs): self.hs = hs @@ -258,6 +265,13 @@ class EventCreationHandler(object): self.config.block_events_without_consent_error ) + # Rooms which should be excluded from dummy insertion. (For instance, + # those without local users who can send events into the room). + # + # map from room id to time-of-last-attempt. + # + self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int] + # we need to construct a ConsentURIBuilder here, as it checks that the necessary # config options, but *only* if we have a configuration for which we are # going to need it. @@ -888,9 +902,11 @@ class EventCreationHandler(object): """Background task to send dummy events into rooms that have a large number of extremities """ - + self._expire_rooms_to_exclude_from_dummy_event_insertion() room_ids = yield self.store.get_rooms_with_many_extremities( - min_count=10, limit=5 + min_count=10, + limit=5, + room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), ) for room_id in room_ids: @@ -904,32 +920,61 @@ class EventCreationHandler(object): members = yield self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) + dummy_event_sent = False + for user_id in members: + if not self.hs.is_mine_id(user_id): + continue + requester = create_requester(user_id) + try: + event, context = yield self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) - user_id = None - for member in members: - if self.hs.is_mine_id(member): - user_id = member - break - - if not user_id: - # We don't have a joined user. - # TODO: We should do something here to stop the room from - # appearing next time. - continue + event.internal_metadata.proactively_send = False - requester = create_requester(user_id) + yield self.send_nonmember_event( + requester, event, context, ratelimit=False + ) + dummy_event_sent = True + break + except ConsentNotGivenError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of consent. Will try another user" % (room_id, user_id) + ) + except AuthError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of power. Will try another user" % (room_id, user_id) + ) - event, context = yield self.create_event( - requester, - { - "type": "org.matrix.dummy_event", - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - prev_events_and_hashes=prev_events_and_hashes, + if not dummy_event_sent: + # Did not find a valid user in the room, so remove from future attempts + # Exclusion is time limited, so the room will be rechecked in the future + # dependent on _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY + logger.info( + "Failed to send dummy event into room %s. Will exclude it from " + "future attempts until cache expires" % (room_id,) + ) + now = self.clock.time_msec() + self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now + + def _expire_rooms_to_exclude_from_dummy_event_insertion(self): + expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY + to_expire = set() + for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items(): + if time < expire_before: + to_expire.add(room_id) + for room_id in to_expire: + logger.debug( + "Expiring room id %s from dummy event insertion exclusion cache", + room_id, ) - - event.internal_metadata.proactively_send = False - - yield self.send_nonmember_event(requester, event, context, ratelimit=False) + del self._rooms_to_exclude_from_dummy_event_insertion[room_id] -- cgit 1.5.1 From 3ca4c7c516a349cbb9dace0ace33bb889955eda1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 12:02:36 +0100 Subject: Use new EventPersistenceStore --- synapse/handlers/federation.py | 3 ++- synapse/handlers/message.py | 3 ++- synapse/server.py | 9 ++++++++- tests/replication/slave/storage/_base.py | 1 + tests/replication/slave/storage/test_events.py | 10 +++++++--- tests/storage/test_redaction.py | 11 ++++++----- tests/storage/test_room.py | 3 ++- tests/storage/test_roommember.py | 3 ++- tests/storage/test_state.py | 3 ++- tests/test_visibility.py | 7 ++++--- tests/utils.py | 10 ++++++++-- 11 files changed, 44 insertions(+), 19 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4b4c6c15f9..7e9c9aee13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -109,6 +109,7 @@ class FederationHandler(BaseHandler): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -2648,7 +2649,7 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) else: - max_stream_id = yield self.store.persist_events( + max_stream_id = yield self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0f8cce8ffe..7908a2d52c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -234,6 +234,7 @@ class EventCreationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -868,7 +869,7 @@ class EventCreationHandler(object): if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - (event_stream_id, max_stream_id) = yield self.store.persist_event( + event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event, context=context ) diff --git a/synapse/server.py b/synapse/server.py index 1fcc7375d3..54a7f4aa5f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -95,6 +95,7 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler +from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -196,6 +197,7 @@ class HomeServer(object): "account_validity_handler", "saml_handler", "event_client_serializer", + "storage", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -225,6 +227,7 @@ class HomeServer(object): self.registration_ratelimiter = Ratelimiter() self.datastore = None + self.datastores = None # Other kwargs are explicit dependencies for depname in kwargs: @@ -234,6 +237,7 @@ class HomeServer(object): logger.info("Setting up.") with self.get_db_conn() as conn: self.datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(self.datastore, conn, self) conn.commit() logger.info("Finished setting up.") @@ -266,7 +270,7 @@ class HomeServer(object): return self.clock def get_datastore(self): - return self.datastore + return self.datastores.main def get_config(self): return self.config @@ -537,6 +541,9 @@ class HomeServer(object): def build_event_client_serializer(self): return EventClientSerializer(self) + def build_storage(self) -> Storage: + return Storage(self, self.datastores) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 104349cdbd..4f924ce451 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() + self.storage = hs.get_storage() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index a368117b43..b68e9fe082 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) + self.get_success( + self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + ) self.replicate() event_source = RoomEventSource(self.hs) @@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.master_store.persist_events([(event, context)], backfilled=True) + self.storage.persistence.persist_events( + [(event, context)], backfilled=True + ) ) else: - self.get_success(self.master_store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 427d3c49c5..4561c3e383 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.store.persist_event(event_1, context_1)) + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.store.persist_event(event_2, context_2)) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1bee45706f..3ddaa151fe 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 447a3c6ffb..9ddd17f73d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5c2cf3c2db..d573a3e07b 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -34,6 +34,7 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -63,7 +64,7 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 18f1a0035d..6ae1ea9b04 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -36,6 +36,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") @@ -137,7 +138,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield self.event_creation_handler.create_new_client_event( builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -159,7 +160,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks @@ -180,7 +181,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): builder ) - yield self.hs.get_datastore().persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index 0a64f75d04..2808eea933 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -326,10 +326,16 @@ def setup_test_homeserver( if homeserverToUse.__name__ == "TestHomeServer": hs.setup_master() else: + # If we have been given an explicit datastore we probably want to mock + # out the DataStores somehow too. This all feels a bit wrong, but then + # mocking the stores feels wrong too. + datastores = Mock(datastore=datastore) + hs = homeserverToUse( name, db_pool=None, datastore=datastore, + datastores=datastores, config=config, version_string="Synapse/tests", database_engine=db_engine, @@ -647,7 +653,7 @@ def create_room(hs, room_id, creator_id): creator_id (str) """ - store = hs.get_datastore() + persistence_store = hs.get_storage().persistence event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() @@ -664,4 +670,4 @@ def create_room(hs, room_id, creator_id): event, context = yield event_creation_handler.create_new_client_event(builder) - yield store.persist_event(event, context) + yield persistence_store.persist_event(event, context) -- cgit 1.5.1 From 69f0054ce675bd9d35104c39af9fae9a908b7f33 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 17:25:54 +0100 Subject: Port to use state storage --- synapse/handlers/admin.py | 7 +- synapse/handlers/device.py | 3 +- synapse/handlers/events.py | 6 +- synapse/handlers/federation.py | 19 ++--- synapse/handlers/initial_sync.py | 14 ++-- synapse/handlers/message.py | 10 +-- synapse/handlers/pagination.py | 6 +- synapse/handlers/room.py | 6 +- synapse/handlers/search.py | 12 ++-- synapse/handlers/sync.py | 20 +++--- synapse/notifier.py | 6 +- synapse/push/httppusher.py | 3 +- synapse/push/mailer.py | 3 +- synapse/push/push_tools.py | 9 +-- synapse/state/__init__.py | 13 ++-- synapse/visibility.py | 30 ++++---- tests/storage/test_state.py | 150 ++++++++++++++++++++++++++------------- tests/test_state.py | 3 + tests/test_visibility.py | 11 ++- 19 files changed, 216 insertions(+), 115 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1a87b58838..6407d56f8e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,6 +30,9 @@ class AdminHandler(BaseHandler): def __init__(self, hs): super(AdminHandler, self).__init__(hs) + self.storage = hs.get_storage() + self.state_store = self.storage.state + @defer.inlineCallbacks def get_whois(self, user): connections = [] @@ -205,7 +208,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.store, user_id, events) + events = yield filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -241,7 +244,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.store.get_state_for_event(event_id) + state = yield self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..b3fd7e6249 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler): self.hs = hs self.state = hs.get_state_handler() + self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() @trace @@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler): continue # mapping from event_id -> state_dict - prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) + prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5e748687e3..45fe13c62f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler): class EventHandler(BaseHandler): + def __init__(self, hs): + super(EventHandler, self).__init__(hs) + self.storage = hs.get_storage() + @defer.inlineCallbacks def get_event(self, user, room_id, event_id): """Retrieve a single specified event. @@ -172,7 +176,7 @@ class EventHandler(BaseHandler): is_peeking = user.to_string() not in users filtered = yield filter_events_for_client( - self.store, user.to_string(), [event], is_peeking=is_peeking + self.storage, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08276fdebf..4d9e33346d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -110,6 +110,7 @@ class FederationHandler(BaseHandler): self.store = hs.get_datastore() self.storage = hs.get_storage() + self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -325,7 +326,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.store.get_state_groups_ids(room_id, seen) + ours = yield self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -889,7 +890,7 @@ class FederationHandler(BaseHandler): # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = yield filter_events_for_server( - self.store, + self.storage, self.server_name, list(extremities_events.values()), redact=False, @@ -1550,7 +1551,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) if state_groups: _, state = list(iteritems(state_groups)).pop() @@ -1579,7 +1580,7 @@ class FederationHandler(BaseHandler): event_id, allow_none=False, check_room_id=room_id ) - state_groups = yield self.store.get_state_groups_ids(room_id, [event_id]) + state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) if state_groups: _, state = list(state_groups.items()).pop() @@ -1607,7 +1608,7 @@ class FederationHandler(BaseHandler): events = yield self.store.get_backfill_events(room_id, pdu_list, limit) - events = yield filter_events_for_server(self.store, origin, events) + events = yield filter_events_for_server(self.storage, origin, events) return events @@ -1637,7 +1638,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - events = yield filter_events_for_server(self.store, origin, [event]) + events = yield filter_events_for_server(self.storage, origin, [event]) event = events[0] return event else: @@ -1903,7 +1904,7 @@ class FederationHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets = yield self.store.get_state_groups( + state_sets = yield self.state_store.get_state_groups( event.room_id, extrem_ids ) state_sets = list(state_sets.values()) @@ -1994,7 +1995,7 @@ class FederationHandler(BaseHandler): ) missing_events = yield filter_events_for_server( - self.store, origin, missing_events + self.storage, origin, missing_events ) return missing_events @@ -2235,7 +2236,7 @@ class FederationHandler(BaseHandler): # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index f991efeee3..49c9e031f9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler): self.validator = EventValidator() self.snapshot_cache = SnapshotCache() self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state def snapshot_all_rooms( self, @@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler): elif event.membership == Membership.LEAVE: room_end_token = "s%d" % (event.stream_ordering,) deferred_room_state = run_in_background( - self.store.get_state_for_events, [event.event_id] + self.state_store.get_state_for_events, [event.event_id] ) deferred_room_state.addCallback( lambda states: states[event.event_id] @@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client(self.store, user_id, messages) + messages = yield filter_events_for_client( + self.storage, user_id, messages + ) start_token = now_token.copy_and_replace("room_key", token) end_token = now_token.copy_and_replace("room_key", room_end_token) @@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler): def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.store.get_state_for_events([member_event_id]) + room_state = yield self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace("room_key", token) @@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler): ) messages = yield filter_events_for_client( - self.store, user_id, messages, is_peeking=is_peeking + self.storage, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace("room_key", token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7908a2d52c..6e2a360262 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,6 +59,8 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @@ -82,7 +84,7 @@ class MessageHandler(object): data = yield self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -135,12 +137,12 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.store, user_id, last_events + self.storage, user_id, last_events ) event = last_events[0] if visible_events: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -161,7 +163,7 @@ class MessageHandler(object): ) room_state = yield self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.store.get_state_for_events( + room_state = yield self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5744f4579d..b7185fe7a0 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -69,6 +69,8 @@ class PaginationHandler(object): self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state self.clock = hs.get_clock() self._server_name = hs.hostname @@ -255,7 +257,7 @@ class PaginationHandler(object): events = event_filter.filter(events) events = yield filter_events_for_client( - self.store, user_id, events, is_peeking=(member_event_id is None) + self.storage, user_id, events, is_peeking=(member_event_id is None) ) if not events: @@ -274,7 +276,7 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..84bad39815 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -822,6 +822,8 @@ class RoomContextHandler(object): def __init__(self, hs): self.hs = hs self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, event_filter): @@ -848,7 +850,7 @@ class RoomContextHandler(object): def filter_evts(events): return filter_events_for_client( - self.store, user.to_string(), events, is_peeking=is_peeking + self.storage, user.to_string(), events, is_peeking=is_peeking ) event = yield self.store.get_event( @@ -890,7 +892,7 @@ class RoomContextHandler(object): # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = yield self.store.get_state_for_events( + state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) results["state"] = list(state[last_event_id].values()) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd5e90bacb..f4d8a60774 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -35,6 +35,8 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) self._event_serializer = hs.get_event_client_serializer() + self.storage = hs.get_storage() + self.state_store = self.storage.state @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -221,7 +223,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -271,7 +273,7 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) events = yield filter_events_for_client( - self.store, user.to_string(), filtered_events + self.storage, user.to_string(), filtered_events ) room_events.extend(events) @@ -340,11 +342,11 @@ class SearchHandler(BaseHandler): ) res["events_before"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_before"] + self.storage, user.to_string(), res["events_before"] ) res["events_after"] = yield filter_events_for_client( - self.store, user.to_string(), res["events_after"] + self.storage, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( @@ -372,7 +374,7 @@ class SearchHandler(BaseHandler): [(EventTypes.Member, sender) for sender in senders] ) - state = yield self.store.get_state_for_event( + state = yield self.state_store.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index d99160e9d7..43a082dcda 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -230,6 +230,8 @@ class SyncHandler(object): self.response_cache = ResponseCache(hs, "sync") self.state = hs.get_state_handler() self.auth = hs.get_auth() + self.storage = hs.get_storage() + self.state_store = self.storage.state # ExpiringCache((User, Device)) -> LruCache(state_key => event_id) self.lazy_loaded_members_cache = ExpiringCache( @@ -417,7 +419,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -470,7 +472,7 @@ class SyncHandler(object): current_state_ids = frozenset(itervalues(current_state_ids)) loaded_recents = yield filter_events_for_client( - self.store, + self.storage, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -509,7 +511,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -580,7 +582,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -757,11 +759,11 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) @@ -781,7 +783,7 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.store.get_state_ids_for_event( + state_at_timeline_start = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: @@ -810,7 +812,7 @@ class SyncHandler(object): ) if batch: - current_state_ids = yield self.store.get_state_ids_for_event( + current_state_ids = yield self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: @@ -841,7 +843,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.store.get_state_ids_for_event( + state_ids = yield self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..af161a81d7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,7 @@ class Notifier(object): self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] @@ -425,7 +426,10 @@ class Notifier(object): if name == "room": new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..36e26032a1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -64,6 +64,7 @@ class HttpPusher(object): def __init__(self, hs, pusherdict): self.hs = hs self.store = self.hs.get_datastore() + self.storage = self.hs.get_storage() self.clock = self.hs.get_clock() self.state_handler = self.hs.get_state_handler() self.user_id = pusherdict["user_name"] @@ -329,7 +330,7 @@ class HttpPusher(object): return d ctx = yield push_tools.get_context_for_event( - self.store, self.state_handler, event, self.user_id + self.storage, self.state_handler, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5b16ab4ae8..1d15a06a58 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -119,6 +119,7 @@ class Mailer(object): self.store = self.hs.get_datastore() self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() + self.storage = hs.get_storage() self.app_name = app_name logger.info("Created Mailer for app_name %s" % app_name) @@ -389,7 +390,7 @@ class Mailer(object): } the_events = yield filter_events_for_client( - self.store, user_id, results["events_before"] + self.storage, user_id, results["events_before"] ) the_events.append(notif_event) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a54051a726..de5c101a58 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,6 +16,7 @@ from twisted.internet import defer from synapse.push.presentable_names import calculate_room_name, name_from_member_event +from synapse.storage import Storage @defer.inlineCallbacks @@ -43,22 +44,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(store, state_handler, ev, user_id): +def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield store.get_state_ids_for_event(ev.event_id) + room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room name = yield calculate_room_name( - store, room_state_ids, user_id, fallback_to_single_member=False + storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield store.get_event(sender_state_event_id) + sender_state_event = yield storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dc9f5a9008..4e91eb66fe 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -103,6 +103,7 @@ class StateHandler(object): def __init__(self, hs): self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() @@ -271,7 +272,7 @@ class StateHandler(object): else: current_state_ids = prev_state_ids - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=None, @@ -321,7 +322,7 @@ class StateHandler(object): delta_ids = dict(entry.delta_ids) delta_ids[key] = event.event_id - state_group = yield self.store.store_state_group( + state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -334,7 +335,7 @@ class StateHandler(object): delta_ids = entry.delta_ids if entry.state_group is None: - entry.state_group = yield self.store.store_state_group( + entry.state_group = yield self.state_store.store_state_group( event.event_id, event.room_id, prev_group=entry.prev_group, @@ -376,14 +377,16 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids) + state_groups_ids = yield self.state_store.get_state_groups_ids( + room_id, event_ids + ) if len(state_groups_ids) == 0: return _StateCacheEntry(state={}, state_group=None) elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.store.get_state_group_delta(name) + prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, diff --git a/synapse/visibility.py b/synapse/visibility.py index bf0f1eebd8..8c843febd8 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event +from synapse.storage import Storage from synapse.storage.state import StateFilter from synapse.types import get_domain_from_id @@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - store, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() ): """ Check which events a user is allowed to see Args: - store (synapse.storage.DataStore): our datastore (can also be a worker - store) + storage user_id(str): user id to be checked events(list[synapse.events.EventBase]): sequence of events to be checked is_peeking(bool): should be True if: @@ -68,12 +68,12 @@ def filter_events_for_client( events = list(e for e in events if not e.internal_metadata.is_soft_failed()) types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield store.get_state_for_events( + event_id_to_state = yield storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield store.get_global_account_data_by_type_for_user( + ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -84,7 +84,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) def allowed(event): """ @@ -213,13 +213,17 @@ def filter_events_for_client( @defer.inlineCallbacks def filter_events_for_server( - store, server_name, events, redact=True, check_history_visibility_only=False + storage: Storage, + server_name, + events, + redact=True, + check_history_visibility_only=False, ): """Filter a list of events based on whether given server is allowed to see them. Args: - store (DataStore) + storage server_name (str) events (iterable[FrozenEvent]) redact (bool): Whether to return a redacted version of the event, or @@ -274,7 +278,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If thats the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -292,14 +296,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield store.get_events(visibility_ids) + event_map = yield storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in itervalues(event_map) ) if not check_history_visibility_only: - erased_senders = yield store.are_users_erased((e.sender for e in events)) + erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -328,7 +332,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield store.get_state_ids_for_events( + event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -358,7 +362,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield store.get_events( + event_map = yield storage.main.get_events( [ e_id for e_id, key in iteritems(event_id_to_state_key) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d573a3e07b..43200654f1 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() self.storage = hs.get_storage() + self.state_datastore = self.store self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) @@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) @@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/test_state.py b/tests/test_state.py index 610ec9fb46..38246555bd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -158,10 +158,12 @@ class Graph(object): class StateTestCase(unittest.TestCase): def setUp(self): self.store = StateGroupStore() + storage = Mock(main=self.store, state=self.store) hs = Mock( spec_set=[ "config", "get_datastore", + "get_storage", "get_auth", "get_state_handler", "get_clock", @@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) + hs.get_storage.return_value = storage self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 6ae1ea9b04..f7381b2885 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -14,6 +14,8 @@ # limitations under the License. import logging +from mock import Mock + from twisted.internet import defer from twisted.internet.defer import succeed @@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) # the result should be 5 redacted events, and 5 unredacted events. @@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # ... and the filtering happens. filtered = yield filter_events_for_server( - self.store, "test_server", events_to_filter + self.storage, "test_server", events_to_filter ) for i in range(0, len(events_to_filter)): @@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): logger.info("Starting filtering") start = time.time() + + storage = Mock() + storage.main = test_store + storage.state = test_store + filtered = yield filter_events_for_server( test_store, "test_server", events_to_filter ) -- cgit 1.5.1 From 54fef094b31e0401d6d35efdf7d5d6b0b9e5d51f Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 31 Oct 2019 10:23:24 +0000 Subject: Remove usage of deprecated logger.warn method from codebase (#6271) Replace every instance of `logger.warn` with `logger.warning` as the former is deprecated. --- changelog.d/6271.misc | 1 + scripts/move_remote_media_to_new_store.py | 2 +- scripts/synapse_port_db | 6 ++-- synapse/api/auth.py | 2 +- synapse/app/__init__.py | 4 ++- synapse/app/appservice.py | 4 +-- synapse/app/client_reader.py | 4 +-- synapse/app/event_creator.py | 4 +-- synapse/app/federation_reader.py | 4 +-- synapse/app/federation_sender.py | 4 +-- synapse/app/frontend_proxy.py | 4 +-- synapse/app/homeserver.py | 6 ++-- synapse/app/media_repository.py | 4 +-- synapse/app/pusher.py | 4 +-- synapse/app/synchrotron.py | 4 +-- synapse/app/user_dir.py | 4 +-- synapse/config/key.py | 4 +-- synapse/config/logger.py | 2 +- synapse/event_auth.py | 2 +- synapse/federation/federation_base.py | 6 ++-- synapse/federation/federation_client.py | 8 +++-- synapse/federation/federation_server.py | 20 ++++++------ synapse/federation/sender/transaction_manager.py | 4 +-- synapse/federation/transport/server.py | 8 +++-- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/auth.py | 6 ++-- synapse/handlers/device.py | 4 +-- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/federation.py | 36 ++++++++++++---------- synapse/handlers/groups_local.py | 2 +- synapse/handlers/identity.py | 6 ++-- synapse/handlers/message.py | 2 +- synapse/handlers/profile.py | 2 +- synapse/handlers/room.py | 2 +- synapse/http/client.py | 4 +-- synapse/http/federation/srv_resolver.py | 2 +- synapse/http/matrixfederationclient.py | 10 +++--- synapse/http/request_metrics.py | 2 +- synapse/http/server.py | 2 +- synapse/http/servlet.py | 4 +-- synapse/http/site.py | 4 +-- synapse/logging/context.py | 2 +- synapse/push/httppusher.py | 4 +-- synapse/push/push_rule_evaluator.py | 4 +-- synapse/replication/http/_base.py | 2 +- synapse/replication/http/membership.py | 2 +- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/rest/admin/__init__.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v2_alpha/account.py | 14 ++++----- synapse/rest/client/v2_alpha/register.py | 10 +++--- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/media/v1/media_repository.py | 12 +++++--- synapse/rest/media/v1/preview_url_resource.py | 16 +++++----- synapse/rest/media/v1/thumbnail_resource.py | 4 +-- .../resource_limits_server_notices.py | 2 +- synapse/storage/_base.py | 6 ++-- synapse/storage/data_stores/main/pusher.py | 2 +- synapse/storage/data_stores/main/search.py | 2 +- synapse/util/async_helpers.py | 2 +- synapse/util/caches/__init__.py | 2 +- synapse/util/metrics.py | 6 ++-- synapse/util/rlimit.py | 2 +- 65 files changed, 164 insertions(+), 149 deletions(-) create mode 100644 changelog.d/6271.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6271.misc b/changelog.d/6271.misc new file mode 100644 index 0000000000..2369760272 --- /dev/null +++ b/changelog.d/6271.misc @@ -0,0 +1 @@ +Replace every instance of `logger.warn` method with `logger.warning` as the former is deprecated. \ No newline at end of file diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py index 12747c6024..b5b63933ab 100755 --- a/scripts/move_remote_media_to_new_store.py +++ b/scripts/move_remote_media_to_new_store.py @@ -72,7 +72,7 @@ def move_media(origin_server, file_id, src_paths, dest_paths): # check that the original exists original_file = src_paths.remote_media_filepath(origin_server, file_id) if not os.path.exists(original_file): - logger.warn( + logger.warning( "Original for %s/%s (%s) does not exist", origin_server, file_id, diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 54faed1e83..0d3321682c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -157,7 +157,7 @@ class Store( ) except self.database_engine.module.DatabaseError as e: if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) + logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) if i < N: i += 1 conn.rollback() @@ -432,7 +432,7 @@ class Porter(object): for row in rows: d = dict(zip(headers, row)) if "\0" in d['value']: - logger.warn('dropping search row %s', d) + logger.warning('dropping search row %s', d) else: rows_dict.append(d) @@ -647,7 +647,7 @@ class Porter(object): if isinstance(col, bytes): return bytearray(col) elif isinstance(col, string_types) and "\0" in col: - logger.warn( + logger.warning( "DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 53f3bb0fa8..5d0b7d2801 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -497,7 +497,7 @@ class Auth(object): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: - logger.warn("Unrecognised appservice access token.") + logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender return defer.succeed(service) diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index d877c77834..a01bac2997 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -44,6 +44,8 @@ def check_bind_error(e, address, bind_addresses): bind_addresses (list): Addresses on which the service listens. """ if address == "0.0.0.0" and "::" in bind_addresses: - logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]") + logger.warning( + "Failed to listen on 0.0.0.0, continuing because listening on [::]" + ) else: raise e diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 767b87d2db..02b900f382 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -94,7 +94,7 @@ class AppserviceServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -103,7 +103,7 @@ class AppserviceServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dbcc414c42..dadb487d5f 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -153,7 +153,7 @@ class ClientReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -162,7 +162,7 @@ class ClientReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index f20d810ece..d110599a35 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -147,7 +147,7 @@ class EventCreatorServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -156,7 +156,7 @@ class EventCreatorServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 1ef027a88c..418c086254 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -132,7 +132,7 @@ class FederationReaderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -141,7 +141,7 @@ class FederationReaderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 04fbb407af..139221ad34 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -123,7 +123,7 @@ class FederationSenderServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -132,7 +132,7 @@ class FederationSenderServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index 9504bfbc70..e647459d0e 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -204,7 +204,7 @@ class FrontendProxyServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -213,7 +213,7 @@ class FrontendProxyServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index eb54f56853..8997c1f9e7 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -282,7 +282,7 @@ class SynapseHomeServer(HomeServer): reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -291,7 +291,7 @@ class SynapseHomeServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) def run_startup_checks(self, db_conn, database_engine): all_users_native = are_all_users_on_domain( @@ -569,7 +569,7 @@ def run(hs): hs.config.report_stats_endpoint, stats ) except Exception as e: - logger.warn("Error reporting stats: %s", e) + logger.warning("Error reporting stats: %s", e) def performance_stats_init(): try: diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 6bc7202f33..2c6dd3ef02 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -120,7 +120,7 @@ class MediaRepositoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -129,7 +129,7 @@ class MediaRepositoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index d84732ee3c..01a5ffc363 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -114,7 +114,7 @@ class PusherServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -123,7 +123,7 @@ class PusherServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 6a7e2fa707..b14da09f47 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -326,7 +326,7 @@ class SynchrotronServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -335,7 +335,7 @@ class SynchrotronServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index a5d6dc7915..6cb100319f 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -150,7 +150,7 @@ class UserDirectoryServer(HomeServer): ) elif listener["type"] == "metrics": if not self.get_config().enable_metrics: - logger.warn( + logger.warning( ( "Metrics listener configured, but " "enable_metrics is not True!" @@ -159,7 +159,7 @@ class UserDirectoryServer(HomeServer): else: _base.listen_metrics(listener["bind_addresses"], listener["port"]) else: - logger.warn("Unrecognized listener type: %s", listener["type"]) + logger.warning("Unrecognized listener type: %s", listener["type"]) self.get_tcp_replication().start_replication(self) diff --git a/synapse/config/key.py b/synapse/config/key.py index ec5d430afb..52ff1b2621 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -125,7 +125,7 @@ class KeyConfig(Config): # if neither trusted_key_servers nor perspectives are given, use the default. if "perspectives" not in config and "trusted_key_servers" not in config: - logger.warn(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) + logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN) key_servers = [{"server_name": "matrix.org"}] else: key_servers = config.get("trusted_key_servers", []) @@ -156,7 +156,7 @@ class KeyConfig(Config): if not self.macaroon_secret_key: # Unfortunately, there are people out there that don't have this # set. Lets just be "nice" and derive one from their secret key. - logger.warn("Config is missing macaroon_secret_key") + logger.warning("Config is missing macaroon_secret_key") seed = bytes(self.signing_key[0]) self.macaroon_secret_key = hashlib.sha256(seed).digest() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index be92e33f93..2d2c1e54df 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -182,7 +182,7 @@ def _reload_stdlib_logging(*args, log_config=None): logger = logging.getLogger("") if not log_config: - logger.warn("Reloaded a blank config?") + logger.warning("Reloaded a blank config?") logging.config.dictConfig(log_config) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e7b722547b..ec3243b27b 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -77,7 +77,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if auth_events is None: # Oh, we don't know what the state of the room was, so we # are trusting that this is allowed (at least for now) - logger.warn("Trusting event: %s", event.event_id) + logger.warning("Trusting event: %s", event.event_id) return if event.type == EventTypes.Create: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 223aace0d9..0e22183280 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -102,7 +102,7 @@ class FederationBase(object): pass if not res: - logger.warn( + logger.warning( "Failed to find copy of %s with valid signature", pdu.event_id ) @@ -173,7 +173,7 @@ class FederationBase(object): return redacted_event if self.spam_checker.check_event_for_spam(pdu): - logger.warn( + logger.warning( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json(), @@ -185,7 +185,7 @@ class FederationBase(object): def errback(failure, pdu): failure.trap(SynapseError) with PreserveLoggingContext(ctx): - logger.warn( + logger.warning( "Signature check failed for %s: %s", pdu.event_id, failure.getErrorMessage(), diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f5c1632916..595706d01a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -522,12 +522,12 @@ class FederationClient(FederationBase): res = yield callback(destination) return res except InvalidResponseError as e: - logger.warn("Failed to %s via %s: %s", description, destination, e) + logger.warning("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: - logger.warn( + logger.warning( "Failed to %s via %s: %i %s", description, destination, @@ -535,7 +535,9 @@ class FederationClient(FederationBase): e.args[0], ) except Exception: - logger.warn("Failed to %s via %s", description, destination, exc_info=1) + logger.warning( + "Failed to %s via %s", description, destination, exc_info=1 + ) raise SynapseError(502, "Failed to %s via any server" % (description,)) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d5a19764d2..d942d77a72 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -220,7 +220,7 @@ class FederationServer(FederationBase): try: await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn("Ignoring PDUs for room %s from banned server", room_id) + logger.warning("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -233,7 +233,7 @@ class FederationServer(FederationBase): await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) + logger.warning("Error handling PDU %s: %s", event_id, e) pdu_results[event_id] = {"error": str(e)} except Exception as e: f = failure.Failure() @@ -333,7 +333,9 @@ class FederationServer(FederationBase): room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: - logger.warn("Room version %s not in %s", room_version, supported_versions) + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) raise IncompatibleRoomVersionError(room_version=room_version) pdu = await self.handler.on_make_join_request(origin, room_id, user_id) @@ -679,7 +681,7 @@ def server_matches_acl_event(server_name, acl_event): # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warn("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignorning non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -693,7 +695,7 @@ def server_matches_acl_event(server_name, acl_event): # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warn("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignorning non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -703,7 +705,7 @@ def server_matches_acl_event(server_name, acl_event): # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warn("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignorning non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): @@ -717,7 +719,7 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn( + logger.warning( "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) ) return False @@ -772,7 +774,7 @@ class FederationHandlerRegistry(object): async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: - logger.warn("No handler registered for EDU type %s", edu_type) + logger.warning("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: @@ -785,7 +787,7 @@ class FederationHandlerRegistry(object): def on_query(self, query_type, args): handler = self.query_handlers.get(query_type) if not handler: - logger.warn("No handler registered for query type %s", query_type) + logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) return handler(args) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 5b6c79c51a..67b3e1ab6e 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -146,7 +146,7 @@ class TransactionManager(object): if code == 200: for e_id, r in response.get("pdus", {}).items(): if "error" in r: - logger.warn( + logger.warning( "TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, @@ -155,7 +155,7 @@ class TransactionManager(object): ) else: for p in pdus: - logger.warn( + logger.warning( "TX [%s] {%s} Failed to send event %s", destination, txn_id, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 0f16f21c2d..d6c23f22bd 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -202,7 +202,7 @@ def _parse_auth_header(header_bytes): sig = strip_quotes(param_dict["sig"]) return origin, key, sig except Exception as e: - logger.warn( + logger.warning( "Error parsing auth header '%s': %s", header_bytes.decode("ascii", "replace"), e, @@ -287,10 +287,12 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.warn("authenticate_request failed: missing authentication") + logger.warning( + "authenticate_request failed: missing authentication" + ) raise except Exception as e: - logger.warn("authenticate_request failed: %s", e) + logger.warning("authenticate_request failed: %s", e) raise request_tags = { diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dfd7ae041b..d950a8b246 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -181,7 +181,7 @@ class GroupAttestionRenewer(object): elif not self.is_mine_id(user_id): destination = get_domain_from_id(user_id) else: - logger.warn( + logger.warning( "Incorrectly trying to do attestations for user: %r in %r", user_id, group_id, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 8f10b6adbb..29e8ffc295 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -488,7 +488,7 @@ class GroupsServerHandler(object): profile = yield self.profile_handler.get_profile_from_cache(user_id) user_profile.update(profile) except Exception as e: - logger.warn("Error getting profile for %s: %s", user_id, e) + logger.warning("Error getting profile for %s: %s", user_id, e) user_profiles.append(user_profile) return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 333eb30625..7a0f54ca24 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -525,7 +525,7 @@ class AuthHandler(BaseHandler): result = None if not user_infos: - logger.warn("Attempted to login as %s but they do not exist", user_id) + logger.warning("Attempted to login as %s but they do not exist", user_id) elif len(user_infos) == 1: # a single match (possibly not exact) result = user_infos.popitem() @@ -534,7 +534,7 @@ class AuthHandler(BaseHandler): result = (user_id, user_infos[user_id]) else: # multiple matches, none of them exact - logger.warn( + logger.warning( "Attempted to login as %s but it matches more than one user " "inexactly: %r", user_id, @@ -728,7 +728,7 @@ class AuthHandler(BaseHandler): result = yield self.validate_hash(password, password_hash) if not result: - logger.warn("Failed password login for user %s", user_id) + logger.warning("Failed password login for user %s", user_id) return None return user_id diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5f23ee4488..befef2cf3d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -656,7 +656,7 @@ class DeviceListUpdater(object): except (NotRetryingDestination, RequestSendFailed, HttpResponseException): # TODO: Remember that we are now out of sync and try again # later - logger.warn("Failed to handle device list update for %s", user_id) + logger.warning("Failed to handle device list update for %s", user_id) # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -694,7 +694,7 @@ class DeviceListUpdater(object): # up on storing the total list of devices and only handle the # delta instead. if len(devices) > 1000: - logger.warn( + logger.warning( "Ignoring device list snapshot for %s as it has >1K devs (%d)", user_id, len(devices), diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0043cbea17..73b9e120f5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -52,7 +52,7 @@ class DeviceMessageHandler(object): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): - logger.warn( + logger.warning( "Dropping device message from %r with spoofed sender %r", origin, sender_user_id, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 08276fdebf..f1547e3039 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -181,7 +181,7 @@ class FederationHandler(BaseHandler): try: self._sanity_check_event(pdu) except SynapseError as err: - logger.warn( + logger.warning( "[%s %s] Received event failed sanity checks", room_id, event_id ) raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) @@ -302,7 +302,7 @@ class FederationHandler(BaseHandler): # following. if sent_to_us_directly: - logger.warn( + logger.warning( "[%s %s] Rejecting: failed to fetch %d prev events: %s", room_id, event_id, @@ -406,7 +406,7 @@ class FederationHandler(BaseHandler): state = [event_map[e] for e in six.itervalues(state_map)] auth_chain = list(auth_chains) except Exception: - logger.warn( + logger.warning( "[%s %s] Error attempting to resolve state at missing " "prev_events", room_id, @@ -519,7 +519,9 @@ class FederationHandler(BaseHandler): # We failed to get the missing events, but since we need to handle # the case of `get_missing_events` not returning the necessary # events anyway, it is safe to simply log the error and continue. - logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e) + logger.warning( + "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e + ) return logger.info( @@ -546,7 +548,7 @@ class FederationHandler(BaseHandler): yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: - logger.warn( + logger.warning( "[%s %s] Received prev_event %s failed history check.", room_id, event_id, @@ -1060,7 +1062,7 @@ class FederationHandler(BaseHandler): SynapseError if the event does not pass muster """ if len(ev.prev_event_ids()) > 20: - logger.warn( + logger.warning( "Rejecting event %s which has %i prev_events", ev.event_id, len(ev.prev_event_ids()), @@ -1068,7 +1070,7 @@ class FederationHandler(BaseHandler): raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: - logger.warn( + logger.warning( "Rejecting event %s which has %i auth_events", ev.event_id, len(ev.auth_event_ids()), @@ -1204,7 +1206,7 @@ class FederationHandler(BaseHandler): with nested_logging_context(p.event_id): yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: - logger.warn( + logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e ) @@ -1251,7 +1253,7 @@ class FederationHandler(BaseHandler): builder=builder ) except AuthError as e: - logger.warn("Failed to create join to %s because %s", room_id, e) + logger.warning("Failed to create join to %s because %s", room_id, e) raise e event_allowed = yield self.third_party_event_rules.check_event_allowed( @@ -1495,7 +1497,7 @@ class FederationHandler(BaseHandler): room_version, event, context, do_sig_check=False ) except AuthError as e: - logger.warn("Failed to create new leave %r because %s", event, e) + logger.warning("Failed to create new leave %r because %s", event, e) raise e return event @@ -1789,7 +1791,7 @@ class FederationHandler(BaseHandler): # cause SynapseErrors in auth.check. We don't want to give up # the attempt to federate altogether in such cases. - logger.warn("Rejecting %s because %s", e.event_id, err.msg) + logger.warning("Rejecting %s because %s", e.event_id, err.msg) if e == event: raise @@ -1845,7 +1847,9 @@ class FederationHandler(BaseHandler): try: yield self.do_auth(origin, event, context, auth_events=auth_events) except AuthError as e: - logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg) + logger.warning( + "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg + ) context.rejected = RejectedReason.AUTH_ERROR @@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: - logger.warn("Soft-failing %r because %s", event, e) + logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @defer.inlineCallbacks @@ -2038,7 +2042,7 @@ class FederationHandler(BaseHandler): try: event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: - logger.warn("Failed auth resolution for %r because %s", event, e) + logger.warning("Failed auth resolution for %r because %s", event, e) raise e @defer.inlineCallbacks @@ -2432,7 +2436,7 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying new third party invite %r because %s", event, e) + logger.warning("Denying new third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) @@ -2488,7 +2492,7 @@ class FederationHandler(BaseHandler): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as e: - logger.warn("Denying third party invite %r because %s", event, e) + logger.warning("Denying third party invite %r because %s", event, e) raise e yield self._check_signature(event, context) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 46eb9ee88b..92fecbfc44 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -392,7 +392,7 @@ class GroupsLocalHandler(object): try: user_profile = yield self.profile_handler.get_profile(user_id) except Exception as e: - logger.warn("No profile for user %s: %s", user_id, e) + logger.warning("No profile for user %s: %s", user_id, e) user_profile = {} return {"state": "invite", "user_profile": user_profile} diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index ba99ddf76d..000fbf090f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -272,7 +272,7 @@ class IdentityHandler(BaseHandler): changed = False if e.code in (400, 404, 501): # The remote server probably doesn't support unbinding (yet) - logger.warn("Received %d response while unbinding threepid", e.code) + logger.warning("Received %d response while unbinding threepid", e.code) else: logger.error("Failed to unbind threepid on identity server: %s", e) raise SynapseError(500, "Failed to contact identity server") @@ -403,7 +403,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " @@ -457,7 +457,7 @@ class IdentityHandler(BaseHandler): if self.hs.config.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use - logger.warn( + logger.warning( 'The config option "trust_identity_server_for_password_resets" ' 'has been replaced by "account_threepid_delegate". ' "Please consult the sample config at docs/sample_config.yaml for " diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7908a2d52c..5698e5fee0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -688,7 +688,7 @@ class EventCreationHandler(object): try: yield self.auth.check_from_context(room_version, event, context) except AuthError as err: - logger.warn("Denying new event %r because %s", event, err) + logger.warning("Denying new event %r because %s", event, err) raise err # Ensure that we can round trip before trying to persist in db diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 8690f69d45..22e0a04da4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -275,7 +275,7 @@ class BaseProfileHandler(BaseHandler): ratelimit=False, # Try to hide that these events aren't atomic. ) except Exception as e: - logger.warn( + logger.warning( "Failed to update join event for room %s - %s", room_id, str(e) ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2816bd8f87..445a08f445 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -922,7 +922,7 @@ class RoomEventSource(object): from_token = RoomStreamToken.parse(from_key) if from_token.topological: - logger.warn("Stream has topological part!!!! %r", from_key) + logger.warning("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) app_service = self.store.get_app_service_by_user_id(user.to_string()) diff --git a/synapse/http/client.py b/synapse/http/client.py index cdf828a4ff..2df5b383b5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -535,7 +535,7 @@ class SimpleHttpClient(object): b"Content-Length" in resp_headers and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) raise SynapseError( 502, "Requested file is too large > %r bytes" % (self.max_size,), @@ -543,7 +543,7 @@ class SimpleHttpClient(object): ) if response.code > 299: - logger.warn("Got %d when downloading %s" % (response.code, url)) + logger.warning("Got %d when downloading %s" % (response.code, url)) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) # TODO: if our Content-Type is HTML or something, just read the first diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 3fe4ffb9e5..021b233a7d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -148,7 +148,7 @@ class SrvResolver(object): # Try something in the cache, else rereaise cache_entry = self._cache.get(service_name, None) if cache_entry: - logger.warn( + logger.warning( "Failed to resolve %r, falling back to cache. %r", service_name, e ) return list(cache_entry) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 3f7c93ffcb..691380abda 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -149,7 +149,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): body = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, @@ -457,7 +457,7 @@ class MatrixFederationHttpClient(object): except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. - logger.warn( + logger.warning( "{%s} [%s] Failed to get error response: %s %s: %s", request.txn_id, request.destination, @@ -478,7 +478,7 @@ class MatrixFederationHttpClient(object): break except RequestSendFailed as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -513,7 +513,7 @@ class MatrixFederationHttpClient(object): raise except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Request failed: %s %s: %s", request.txn_id, request.destination, @@ -889,7 +889,7 @@ class MatrixFederationHttpClient(object): d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: - logger.warn( + logger.warning( "{%s} [%s] Error reading response: %s", request.txn_id, request.destination, diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 46af27c8f6..58f9cc61c8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -170,7 +170,7 @@ class RequestMetrics(object): tag = context.tag if context != self.start_context: - logger.warn( + logger.warning( "Context have unexpectedly changed %r, %r", context, self.start_context, diff --git a/synapse/http/server.py b/synapse/http/server.py index 2ccb210fd6..943d12c907 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -454,7 +454,7 @@ def respond_with_json( # the Deferred fires, but since the flag is RIGHT THERE it seems like # a waste. if request._disconnected: - logger.warn( + logger.warning( "Not sending response to request %s, already disconnected.", request ) return diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 274c1a6a87..e9a5e46ced 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -219,13 +219,13 @@ def parse_json_value_from_request(request, allow_empty_body=False): try: content_unicode = content_bytes.decode("utf8") except UnicodeDecodeError: - logger.warn("Unable to decode UTF-8") + logger.warning("Unable to decode UTF-8") raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) try: content = json.loads(content_unicode) except Exception as e: - logger.warn("Unable to parse JSON: %s", e) + logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) return content diff --git a/synapse/http/site.py b/synapse/http/site.py index df5274c177..ff8184a3d0 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -199,7 +199,7 @@ class SynapseRequest(Request): # It's useful to log it here so that we can get an idea of when # the client disconnects. with PreserveLoggingContext(self.logcontext): - logger.warn( + logger.warning( "Error processing request %r: %s %s", self, reason.type, reason.value ) @@ -305,7 +305,7 @@ class SynapseRequest(Request): try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: - logger.warn("Failed to stop metrics: %r", e) + logger.warning("Failed to stop metrics: %r", e) class XForwardedForRequest(SynapseRequest): diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 370000e377..2c1fb9ddac 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -294,7 +294,7 @@ class LoggingContext(object): """Enters this logging context into thread local storage""" old_context = self.set_current_context(self) if self.previous_context != old_context: - logger.warn( + logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 6299587808..23d3678420 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -246,7 +246,7 @@ class HttpPusher(object): # we really only give up so that if the URL gets # fixed, we don't suddenly deliver a load # of old notifications. - logger.warn( + logger.warning( "Giving up on a notification to user %s, " "pushkey %s", self.user_id, self.pushkey, @@ -299,7 +299,7 @@ class HttpPusher(object): if pk != self.pushkey: # for sanity, we only remove the pushkey if it # was the one we actually sent... - logger.warn( + logger.warning( ("Ignoring rejected pushkey %s because we" " didn't send it"), pk, ) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 5ed9147de4..b1587183a8 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -117,7 +117,7 @@ class PushRuleEvaluatorForEvent(object): pattern = UserID.from_string(user_id).localpart if not pattern: - logger.warn("event_match condition with no pattern") + logger.warning("event_match condition with no pattern") return False # XXX: optimisation: cache our pattern regexps @@ -173,7 +173,7 @@ def _glob_matches(glob, value, word_boundary=False): regex_cache[(glob, word_boundary)] = r return r.search(value) except re.error: - logger.warn("Failed to parse glob to regex: %r", glob) + logger.warning("Failed to parse glob to regex: %r", glob) return False diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 9be37cd998..c8056b0c0c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -180,7 +180,7 @@ class ReplicationEndpoint(object): if e.code != 504 or not cls.RETRY_ON_TIMEOUT: raise - logger.warn("%s request timed out", cls.NAME) + logger.warning("%s request timed out", cls.NAME) # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index b5f5f13a62..cc1f249740 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -144,7 +144,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # The 'except' clause is very broad, but we need to # capture everything from DNS failures upwards # - logger.warn("Failed to reject invite: %s", e) + logger.warning("Failed to reject invite: %s", e) await self.store.locally_reject_invite(user_id, room_id) ret = {} diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a44ceb00e7..563ce0fc53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -168,7 +168,7 @@ class ReplicationClientHandler(object): if self.connection: self.connection.send_command(cmd) else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) + logger.warning("Queuing command as not connected: %r", cmd.NAME) self.pending_commands.append(cmd) def send_federation_ack(self, token): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 5ffdf2675d..b64f3f44b5 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -249,7 +249,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return handler(cmd) def close(self): - logger.warn("[%s] Closing connection", self.id()) + logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 939418ee2b..5c2a2eb593 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -286,7 +286,7 @@ class PurgeHistoryRestServlet(RestServlet): room_id, stream_ordering ) if not r: - logger.warn( + logger.warning( "[purge] purging events not possible: No event found " "(received_ts %i => stream_ordering %i)", ts, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8414af08cb..39a5c5e9de 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -221,7 +221,7 @@ class LoginRestServlet(RestServlet): medium, address ) if not user_id: - logger.warn( + logger.warning( "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 80cf7126a0..332d7138b1 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -71,7 +71,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User password resets have been disabled due to lack of email config" ) raise SynapseError( @@ -162,7 +162,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Password reset emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -183,7 +183,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -350,7 +350,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -441,7 +441,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -488,7 +488,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( @@ -515,7 +515,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 4f24a124a6..6c7d25d411 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -106,7 +106,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "Email registration has been disabled due to lack of email config" ) raise SynapseError( @@ -207,7 +207,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if not self.hs.config.account_threepid_delegate_msisdn: - logger.warn( + logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" ) @@ -266,7 +266,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ) if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warn( + logger.warning( "User registration via email has been disabled due to lack of email config" ) raise SynapseError( @@ -287,7 +287,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: if next_link.startswith("file:///"): - logger.warn( + logger.warning( "Not redirecting to next_link as it is a local file: address" ) else: @@ -480,7 +480,7 @@ class RegisterRestServlet(RestServlet): # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params - logger.warn("Ignoring initial_device_display_name without password") + logger.warning("Ignoring initial_device_display_name without password") del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 541a6b0e10..ccd8b17b23 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -394,7 +394,7 @@ class SyncRestServlet(RestServlet): # We've had bug reports that events were coming down under the # wrong room. if event.room_id != room.room_id: - logger.warn( + logger.warning( "Event %r is under room %r instead of %r", event.event_id, room.room_id, diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index b972e152a9..bd9186fe50 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -363,7 +363,7 @@ class MediaRepository(object): }, ) except RequestSendFailed as e: - logger.warn( + logger.warning( "Request failed fetching remote media %s/%s: %r", server_name, media_id, @@ -372,7 +372,7 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except HttpResponseException as e: - logger.warn( + logger.warning( "HTTP error fetching remote media %s/%s: %s", server_name, media_id, @@ -383,10 +383,12 @@ class MediaRepository(object): raise SynapseError(502, "Failed to fetch remote media") except SynapseError: - logger.warn("Failed to fetch remote media %s/%s", server_name, media_id) + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) raise except NotRetryingDestination: - logger.warn("Not retrying destination %r", server_name) + logger.warning("Not retrying destination %r", server_name) raise SynapseError(502, "Failed to fetch remote media") except Exception: logger.exception( @@ -691,7 +693,7 @@ class MediaRepository(object): try: os.remove(full_path) except OSError as e: - logger.warn("Failed to remove file: %r", full_path) + logger.warning("Failed to remove file: %r", full_path) if e.errno == errno.ENOENT: pass else: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 094ebad770..5a25b6b3fc 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -136,7 +136,7 @@ class PreviewUrlResource(DirectServeResource): match = False continue if match: - logger.warn("URL %s blocked by url_blacklist entry %s", url, entry) + logger.warning("URL %s blocked by url_blacklist entry %s", url, entry) raise SynapseError( 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN ) @@ -208,7 +208,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s" % url) + logger.warning("Couldn't get dims for %s" % url) # define our OG response for this media elif _is_html(media_info["media_type"]): @@ -256,7 +256,7 @@ class PreviewUrlResource(DirectServeResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warn("Couldn't get dims for %s", og["og:image"]) + logger.warning("Couldn't get dims for %s", og["og:image"]) og["og:image"] = "mxc://%s/%s" % ( self.server_name, @@ -267,7 +267,7 @@ class PreviewUrlResource(DirectServeResource): else: del og["og:image"] else: - logger.warn("Failed to find any OG data in %s", url) + logger.warning("Failed to find any OG data in %s", url) og = {} logger.debug("Calculated OG for %s as %s", url, og) @@ -319,7 +319,7 @@ class PreviewUrlResource(DirectServeResource): ) except Exception as e: # FIXME: pass through 404s and other error messages nicely - logger.warn("Error downloading %s: %r", url, e) + logger.warning("Error downloading %s: %r", url, e) raise SynapseError( 500, @@ -400,7 +400,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) @@ -432,7 +432,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue try: @@ -448,7 +448,7 @@ class PreviewUrlResource(DirectServeResource): except OSError as e: # If the path doesn't exist, meh if e.errno != errno.ENOENT: - logger.warn("Failed to remove media: %r: %s", media_id, e) + logger.warning("Failed to remove media: %r: %s", media_id, e) continue removed_media.append(media_id) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 08329884ac..931ce79be8 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -182,7 +182,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks @@ -245,7 +245,7 @@ class ThumbnailResource(DirectServeResource): if file_path: yield respond_with_file(request, desired_type, file_path) else: - logger.warn("Failed to generate thumbnail") + logger.warning("Failed to generate thumbnail") respond_404(request) @defer.inlineCallbacks diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index c0e7f475c9..9fae2e0afe 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -83,7 +83,7 @@ class ResourceLimitsServerNotices(object): room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id) if not room_id: - logger.warn("Failed to get server notices room") + logger.warning("Failed to get server notices room") return yield self._check_and_set_tags(user_id, room_id) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f5906fcd54..1a2b7ebe25 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -494,7 +494,7 @@ class SQLBaseStore(object): exception_callbacks = [] if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn("Starting db txn '%s' from sentinel context", desc) + logger.warning("Starting db txn '%s' from sentinel context", desc) try: result = yield self.runWithConnection( @@ -532,7 +532,7 @@ class SQLBaseStore(object): """ parent_context = LoggingContext.current_context() if parent_context == LoggingContext.sentinel: - logger.warn( + logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None @@ -719,7 +719,7 @@ class SQLBaseStore(object): raise # presumably we raced with another transaction: let's retry. - logger.warn( + logger.warning( "IntegrityError when upserting into %s; retrying: %s", table, e ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f005c1ae0a..d76861cdc0 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -44,7 +44,7 @@ class PusherWorkerStore(SQLBaseStore): r["data"] = json.loads(dataJson) except Exception as e: - logger.warn( + logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], dataJson, diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 0e08497452..a59b8331e1 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -196,7 +196,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): " ON event_search USING GIN (vector)" ) except psycopg2.ProgrammingError as e: - logger.warn( + logger.warning( "Ignoring error %r when trying to switch from GIST to GIN", e ) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index b60a604474..5c4de2e69f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -309,7 +309,7 @@ class Linearizer(object): ) else: - logger.warn( + logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", self.name, key, diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 43fd65d693..da5077b471 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -107,7 +107,7 @@ def register_cache(cache_type, cache_name, cache, collect_callback=None): if collect_callback: collect_callback() except Exception as e: - logger.warn("Error calculating metrics for %s: %s", cache_name, e) + logger.warning("Error calculating metrics for %s: %s", cache_name, e) raise yield GaugeMetricFamily("__unused", "") diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 4b1bcdf23c..3286804322 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -119,7 +119,7 @@ class Measure(object): context = LoggingContext.current_context() if context != self.start_context: - logger.warn( + logger.warning( "Context has unexpectedly changed from '%s' to '%s'. (%r)", self.start_context, context, @@ -128,7 +128,7 @@ class Measure(object): return if not context: - logger.warn("Expected context. (%r)", self.name) + logger.warning("Expected context. (%r)", self.name) return current = context.get_resource_usage() @@ -140,7 +140,7 @@ class Measure(object): block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warn( + logger.warning( "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current ) diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py index 6c0f2bb0cf..207cd17c2a 100644 --- a/synapse/util/rlimit.py +++ b/synapse/util/rlimit.py @@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no): resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY) ) except (ValueError, resource.error) as e: - logger.warn("Failed to set file or core limit: %s", e) + logger.warning("Failed to set file or core limit: %s", e) -- cgit 1.5.1 From 020add50997f697c7847ac84b86b457ba2f3e32d Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 1 Nov 2019 02:43:24 +1100 Subject: Update black to 19.10b0 (#6304) * update version of black and also fix the mypy config being overridden --- changelog.d/6304.misc | 1 + contrib/experiments/test_messaging.py | 4 +-- mypy.ini | 11 ++++--- synapse/federation/sender/per_destination_queue.py | 11 ++++--- synapse/handlers/account_data.py | 7 ++-- synapse/handlers/appservice.py | 5 ++- synapse/handlers/e2e_keys.py | 37 ++++++++++++++-------- synapse/handlers/federation.py | 9 +++--- synapse/handlers/initial_sync.py | 4 +-- synapse/handlers/message.py | 14 ++++---- synapse/handlers/pagination.py | 13 ++++---- synapse/handlers/register.py | 4 +-- synapse/handlers/room.py | 29 +++++++++-------- synapse/handlers/room_member.py | 35 ++++++++++---------- synapse/handlers/search.py | 12 +++---- synapse/handlers/stats.py | 5 ++- synapse/handlers/sync.py | 16 ++++++---- synapse/logging/_structured.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 7 ++-- synapse/push/emailpusher.py | 14 ++++---- synapse/push/httppusher.py | 14 ++++---- synapse/push/pusherpool.py | 4 +-- synapse/rest/client/v1/login.py | 13 ++++---- synapse/rest/client/v2_alpha/account.py | 4 +-- synapse/rest/client/v2_alpha/register.py | 4 +-- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/server.pyi | 16 +++++----- synapse/storage/data_stores/main/__init__.py | 4 +-- .../storage/data_stores/main/event_push_actions.py | 2 +- synapse/storage/data_stores/main/events.py | 8 ++--- .../storage/data_stores/main/events_bg_updates.py | 2 +- synapse/storage/data_stores/main/group_server.py | 4 +-- .../data_stores/main/monthly_active_users.py | 2 +- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/registration.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/storage/data_stores/main/search.py | 2 +- synapse/storage/data_stores/main/state.py | 20 ++++++------ synapse/storage/data_stores/main/stats.py | 4 +-- synapse/storage/util/id_generators.py | 2 +- tox.ini | 4 +-- 41 files changed, 191 insertions(+), 166 deletions(-) create mode 100644 changelog.d/6304.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6304.misc b/changelog.d/6304.misc new file mode 100644 index 0000000000..20372b4f7c --- /dev/null +++ b/changelog.d/6304.misc @@ -0,0 +1 @@ +Update the version of black used to 19.10b0. diff --git a/contrib/experiments/test_messaging.py b/contrib/experiments/test_messaging.py index 6b22400a60..3bbbcfa1b4 100644 --- a/contrib/experiments/test_messaging.py +++ b/contrib/experiments/test_messaging.py @@ -78,7 +78,7 @@ class InputOutput(object): m = re.match("^join (\S+)$", line) if m: # The `sender` wants to join a room. - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("%s joining %s" % (self.user, room_name)) self.server.join_room(room_name, self.user, self.user) # self.print_line("OK.") @@ -105,7 +105,7 @@ class InputOutput(object): m = re.match("^backfill (\S+)$", line) if m: # we want to backfill a room - room_name, = m.groups() + (room_name,) = m.groups() self.print_line("backfill %s" % room_name) self.server.backfill(room_name) return diff --git a/mypy.ini b/mypy.ini index ffadaddc0b..1d77c0ecc8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,8 +1,11 @@ [mypy] -namespace_packages=True -plugins=mypy_zope:plugin -follow_imports=skip -mypy_path=stubs +namespace_packages = True +plugins = mypy_zope:plugin +follow_imports = normal +check_untyped_defs = True +show_error_codes = True +show_traceback = True +mypy_path = stubs [mypy-zope] ignore_missing_imports = True diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index cc75c39476..b754a09d7a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -192,15 +192,16 @@ class PerDestinationQueue(object): # We have to keep 2 free slots for presence and rr_edus limit = MAX_EDUS_PER_TRANSACTION - 2 - device_update_edus, dev_list_id = ( - yield self._get_device_update_edus(limit) + device_update_edus, dev_list_id = yield self._get_device_update_edus( + limit ) limit -= len(device_update_edus) - to_device_edus, device_stream_id = ( - yield self._get_to_device_message_edus(limit) - ) + ( + to_device_edus, + device_stream_id, + ) = yield self._get_to_device_message_edus(limit) pending_edus = device_update_edus + to_device_edus diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 38bc67191c..2d7e6df6e4 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -38,9 +38,10 @@ class AccountDataEventSource(object): {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id} ) - account_data, room_account_data = ( - yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) - ) + ( + account_data, + room_account_data, + ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3e9b298154..fe62f78e67 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -73,7 +73,10 @@ class ApplicationServicesHandler(object): try: limit = 100 while True: - upper_bound, events = yield self.store.get_new_events_for_appservice( + ( + upper_bound, + events, + ) = yield self.store.get_new_events_for_appservice( self.current_max, limit ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5ea54f60be..0449034a4e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -119,9 +119,10 @@ class E2eKeysHandler(object): else: query_list.append((user_id, None)) - user_ids_not_in_cache, remote_results = ( - yield self.store.get_user_devices_from_cache(query_list) - ) + ( + user_ids_not_in_cache, + remote_results, + ) = yield self.store.get_user_devices_from_cache(query_list) for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) for device_id, device in iteritems(devices): @@ -688,17 +689,21 @@ class E2eKeysHandler(object): try: # get our self-signing key to verify the signatures - _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "self_signing" - ) + ( + _, + self_signing_key_id, + self_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing") # get our master key, since we may have received a signature of it. # We need to fetch it here so that we know what its key ID is, so # that we can check if a signature that was sent is a signature of # the master key or of a device - master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "master" - ) + ( + master_key, + _, + master_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master") # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to compare with what @@ -838,9 +843,11 @@ class E2eKeysHandler(object): try: # get our user-signing key to verify the signatures - user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "user_signing" - ) + ( + user_signing_key, + user_signing_key_id, + user_signing_verify_key, + ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing") except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): @@ -859,7 +866,11 @@ class E2eKeysHandler(object): try: # get the target user's master key, to make sure it matches # what was sent - master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + ( + master_key, + master_key_id, + _, + ) = yield self._get_e2e_cross_signing_verify_key( target_user, "master", user_id ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d2d9f8c26a..a932d3085f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -352,10 +352,11 @@ class FederationHandler(BaseHandler): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state, got_auth_chain = ( - yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ( + remote_state, + got_auth_chain, + ) = yield self.federation_client.get_state_for_room( + origin, room_id, p ) # we want the state *after* p; get_state_for_room returns the diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 49c9e031f9..81dce96f4b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -128,8 +128,8 @@ class InitialSyncHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + account_data, account_data_by_room = yield self.store.get_account_data_for_user( + user_id ) public_room_ids = yield self.store.get_public_room_ids() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0d546d2487..d682dc2b7a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -76,9 +76,10 @@ class MessageHandler(object): Raises: SynapseError if something went wrong. """ - membership, membership_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: data = yield self.state.get_current_state(room_id, event_type, state_key) @@ -153,9 +154,10 @@ class MessageHandler(object): % (user_id, room_id, at_token), ) else: - membership, membership_event_id = ( - yield self.auth.check_in_room_or_world_readable(room_id, user_id) - ) + ( + membership, + membership_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index b7185fe7a0..97f15a1c32 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -212,9 +212,10 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") with (yield self.pagination_lock.read(room_id)): - membership, member_event_id = yield self.auth.check_in_room_or_world_readable( - room_id, user_id - ) + ( + membership, + member_event_id, + ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -297,10 +298,8 @@ class PaginationHandler(object): } if state: - chunk["state"] = ( - yield self._event_serializer.serialize_events( - state, time_now, as_client_event=as_client_event - ) + chunk["state"] = yield self._event_serializer.serialize_events( + state, time_now, as_client_event=as_client_event ) return chunk diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 53410f120b..cff6b0d375 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -396,8 +396,8 @@ class RegistrationHandler(BaseHandler): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = ( - yield room_member_handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_alias ) room_id = room_id.to_string() else: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 650bd28abb..0182e5b432 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -147,21 +147,22 @@ class RoomCreationHandler(BaseHandler): # we create and auth the tombstone event before properly creating the new # room, to check our user has perms in the old room. - tombstone_event, tombstone_context = ( - yield self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, + ( + tombstone_event, + tombstone_context, + ) = yield self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, }, - token_id=requester.access_token_id, - ) + }, + token_id=requester.access_token_id, ) old_room_version = yield self.store.get_room_version(old_room_id) yield self.auth.check_from_context( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 380e2fad5e..9a940d2c05 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -759,22 +759,25 @@ class RoomMemberHandler(object): if room_avatar_event: room_avatar_url = room_avatar_event.content.get("url", "") - token, public_keys, fallback_public_key, display_name = ( - yield self.identity_handler.ask_id_server_for_third_party_invite( - requester=requester, - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url, - id_access_token=id_access_token, - ) + ( + token, + public_keys, + fallback_public_key, + display_name, + ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + requester=requester, + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url, + id_access_token=id_access_token, ) yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index f4d8a60774..56ed262a1f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -396,15 +396,11 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = ( - yield self._event_serializer.serialize_events( - context["events_before"], time_now - ) + context["events_before"] = yield self._event_serializer.serialize_events( + context["events_before"], time_now ) - context["events_after"] = ( - yield self._event_serializer.serialize_events( - context["events_after"], time_now - ) + context["events_after"] = yield self._event_serializer.serialize_events( + context["events_after"], time_now ) state_results = {} diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 26bc276692..7f7d56390e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -108,7 +108,10 @@ class StatsHandler(StateDeltasHandler): user_deltas = {} # Then count deltas for total_events and total_event_bytes. - room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes( + ( + room_count, + user_count, + ) = yield self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 43a082dcda..b536d410e5 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1206,10 +1206,11 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_updated_account_data_for_user( + user_id, since_token.account_data_key ) push_rules_changed = yield self.store.have_push_rules_changed_for_user( @@ -1221,9 +1222,10 @@ class SyncHandler(object): sync_config.user ) else: - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(sync_config.user.to_string()) - ) + ( + account_data, + account_data_by_room, + ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) account_data["m.push_rules"] = yield self.push_rules_for_user( sync_config.user diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 3220e985a9..334ddaf39a 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -185,7 +185,7 @@ DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} def parse_drain_configs( - drains: dict + drains: dict, ) -> typing.Generator[DrainConfiguration, None, None]: """ Parse the drain configurations. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 2bbdd11941..1ba7bcd4d8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -149,9 +149,10 @@ class BulkPushRuleEvaluator(object): room_members = yield self.store.get_joined_users_from_context(event, context) - (power_levels, sender_power_level) = ( - yield self._get_power_levels_and_sender_level(event, context) - ) + ( + power_levels, + sender_power_level, + ) = yield self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 42e5b0c0a5..8c818a86bf 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -234,14 +234,12 @@ class EmailPusher(object): return self.last_stream_ordering = last_stream_ordering - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 9a1bb64887..7dde2ad055 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -211,14 +211,12 @@ class HttpPusher(object): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = ( - yield self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), - ) + pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 08e840fdc2..0f6992202d 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -103,9 +103,7 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = ( - yield self.store.get_latest_push_action_stream_ordering() - ) + last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() yield self.store.add_pusher( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 39a5c5e9de..00a7dd6d09 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -203,10 +203,11 @@ class LoginRestServlet(RestServlet): address = address.lower() # Check for login providers that support 3pid login types - canonical_user_id, callback_3pid = ( - yield self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) + ( + canonical_user_id, + callback_3pid, + ) = yield self.auth_handler.check_password_provider_3pid( + medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded @@ -280,8 +281,8 @@ class LoginRestServlet(RestServlet): def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = ( - yield auth_handler.validate_short_term_login_token_and_get_user_id(token) + user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + token ) result = yield self._register_device_with_callback(user_id, login_submission) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 332d7138b1..f26eae794c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -148,7 +148,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_password_reset_template_failure_html], ) @@ -479,7 +479,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_add_threepid_template_failure_html], ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 6c7d25d411..91db923814 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -247,13 +247,13 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.failure_email_template, = load_jinja2_templates( + (self.failure_email_template,) = load_jinja2_templates( self.config.email_template_dir, [self.config.email_registration_template_failure_html], ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 55580bc59e..e7fc3f0431 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -102,7 +102,7 @@ class RemoteKey(DirectServeResource): @wrap_json_request_handler async def _async_render_GET(self, request): if len(request.postpath) == 1: - server, = request.postpath + (server,) = request.postpath query = {server.decode("ascii"): {}} elif len(request.postpath) == 2: server, key_id = request.postpath diff --git a/synapse/server.pyi b/synapse/server.pyi index 16f8f6b573..83d1f11283 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -39,7 +39,7 @@ class HomeServer(object): def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler: pass def get_deactivate_account_handler( - self + self, ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler: pass def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler: @@ -47,32 +47,32 @@ class HomeServer(object): def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler: pass def get_event_creation_handler( - self + self, ) -> synapse.handlers.message.EventCreationHandler: pass def get_set_password_handler( - self + self, ) -> synapse.handlers.set_password.SetPasswordHandler: pass def get_federation_sender(self) -> synapse.federation.sender.FederationSender: pass def get_federation_transport_client( - self + self, ) -> synapse.federation.transport.client.TransportLayerClient: pass def get_media_repository_resource( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource: pass def get_media_repository( - self + self, ) -> synapse.rest.media.v1.media_repository.MediaRepository: pass def get_server_notices_manager( - self + self, ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager: pass def get_server_notices_sender( - self + self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index b185ba0b3e..60ae01d972 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -317,7 +317,7 @@ class DataStore( ) u """ txn.execute(sql, (time_from,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count def count_r30_users(self): @@ -396,7 +396,7 @@ class DataStore( txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs)) - count, = txn.fetchone() + (count,) = txn.fetchone() results["all"] = count return results diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 22025effbc..04ce21ac66 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -863,7 +863,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) stream_row = txn.fetchone() if stream_row: - offset_stream_ordering, = stream_row + (offset_stream_ordering,) = stream_row rotate_to_stream_ordering = min( self.stream_ordering_day_ago, offset_stream_ordering ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 64a8a05279..aafc2007d3 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1125,7 +1125,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_messages", _count_messages) @@ -1146,7 +1146,7 @@ class EventsStore( """ txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) @@ -1161,7 +1161,7 @@ class EventsStore( AND stream_ordering > ? """ txn.execute(sql, (self.stream_ordering_day_ago,)) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_daily_active_rooms", _count) @@ -1646,7 +1646,7 @@ class EventsStore( """, (room_id,), ) - min_depth, = txn.fetchone() + (min_depth,) = txn.fetchone() logger.info("[purge] updating room_depth to %d", min_depth) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 31ea6f917f..51352b9966 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -438,7 +438,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if not rows: return 0 - upper_event_id, = rows[-1] + (upper_event_id,) = rows[-1] # Update the redactions with the received_ts. # diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index aeae5a2b28..b3a2771f1b 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -249,7 +249,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND category_id = ? """ txn.execute(sql, (group_id, category_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} @@ -509,7 +509,7 @@ class GroupServerStore(SQLBaseStore): WHERE group_id = ? AND role_id = ? """ txn.execute(sql, (group_id, role_id)) - order, = txn.fetchone() + (order,) = txn.fetchone() if existing: to_update = {} diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index e6ee1e4aaa..b41c3d317a 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -171,7 +171,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" txn.execute(sql) - count, = txn.fetchone() + (count,) = txn.fetchone() return count return self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index cd95f1ce60..b520062d84 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -143,7 +143,7 @@ class PushRulesWorkerStore( " WHERE user_id = ? AND ? < stream_id" ) txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() + (count,) = txn.fetchone() return bool(count) return self.runInteraction( diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 6c5b29288a..f70d41ecab 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -459,7 +459,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE appservice_id IS NULL """ ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count ret = yield self.runInteraction("count_users", _count_users) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index bc04bfd7d4..2af24a20b7 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -927,7 +927,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): if not row or not row[0]: return processed, True - next_room, = row + (next_room,) = row sql = """ UPDATE current_state_events diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index a59b8331e1..d1d7c6863d 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -672,7 +672,7 @@ class SearchStore(SearchBackgroundUpdateStore): ) ) txn.execute(query, (value, search_query)) - headline, = txn.fetchall()[0] + (headline,) = txn.fetchall()[0] # Now we need to pick the possible highlights out of the haedline # result. diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9b2207075b..3132848034 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -725,16 +725,18 @@ class StateGroupWorkerStore( member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - non_member_state, incomplete_groups_nm, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter ) - member_state, incomplete_groups_m, = ( - yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter ) state = dict(non_member_state) @@ -1076,7 +1078,7 @@ class StateBackgroundUpdateStore( " WHERE id < ? AND room_id = ?", (state_group, room_id), ) - prev_group, = txn.fetchone() + (prev_group,) = txn.fetchone() new_last_state_group = state_group if prev_group: diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 4d59b7833f..45b3de7d56 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -773,7 +773,7 @@ class StatsStore(StateDeltasStore): (room_id,), ) - current_state_events_count, = txn.fetchone() + (current_state_events_count,) = txn.fetchone() users_in_room = self.get_users_in_room_txn(txn, room_id) @@ -863,7 +863,7 @@ class StatsStore(StateDeltasStore): """, (user_id,), ) - count, = txn.fetchone() + (count,) = txn.fetchone() return count, pos joined_rooms, pos = yield self.runInteraction( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index cbb0a4810a..9d851beaa5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1): cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) - val, = cur.fetchone() + (val,) = cur.fetchone() cur.close() current_id = int(val) if val else step return (max if step > 0 else min)(current_id, step) diff --git a/tox.ini b/tox.ini index 50b6afe611..afe9bc909b 100644 --- a/tox.ini +++ b/tox.ini @@ -114,7 +114,7 @@ skip_install = True basepython = python3.6 deps = flake8 - black==19.3b0 # We pin so that our tests don't start failing on new releases of black. + black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . /bin/sh -c "flake8 synapse tests scripts scripts-dev synctl {env:PEP8SUFFIX:}" @@ -167,6 +167,6 @@ deps = env = MYPYPATH = stubs/ extras = all -commands = mypy --show-traceback --check-untyped-defs --show-error-codes --follow-imports=normal \ +commands = mypy \ synapse/logging/ \ synapse/config/ -- cgit 1.5.1 From 09957ce0e4dcfd84c2de4039653059faae03065b Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 4 Nov 2019 17:09:22 +0000 Subject: Implement per-room message retention policies --- changelog.d/5815.feature | 1 + docs/sample_config.yaml | 63 ++++ synapse/api/constants.py | 2 + synapse/config/server.py | 172 +++++++++++ synapse/events/validator.py | 100 ++++++- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 4 +- synapse/handlers/pagination.py | 111 +++++++ synapse/storage/data_stores/main/events.py | 3 + synapse/storage/data_stores/main/room.py | 252 ++++++++++++++++ .../main/schema/delta/56/room_retention.sql | 33 +++ synapse/visibility.py | 17 ++ tests/rest/client/test_retention.py | 320 +++++++++++++++++++++ 13 files changed, 1074 insertions(+), 6 deletions(-) create mode 100644 changelog.d/5815.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/room_retention.sql create mode 100644 tests/rest/client/test_retention.py (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/5815.feature b/changelog.d/5815.feature new file mode 100644 index 0000000000..ca4df4e7f6 --- /dev/null +++ b/changelog.d/5815.feature @@ -0,0 +1 @@ +Implement per-room message retention policies. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d2f4aff826..87fba27d13 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -328,6 +328,69 @@ listeners: # #user_ips_max_age: 14d +# Message retention policy at the server level. +# +# Room admins and mods can define a retention period for their rooms using the +# 'm.room.retention' state event, and server admins can cap this period by setting +# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. +# +# If this feature is enabled, Synapse will regularly look for and purge events +# which are older than the room's maximum retention period. Synapse will also +# filter events received over federation so that events that should have been +# purged are ignored and not stored again. +# +retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h + ## TLS ## diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 49c4b85054..e3f086f1c3 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -94,6 +94,8 @@ class EventTypes(object): ServerACL = "m.room.server_acl" Pinned = "m.room.pinned_events" + Retention = "m.room.retention" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/config/server.py b/synapse/config/server.py index d556df308d..aa93a416f1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -246,6 +246,115 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + self.retention_allowed_lifetime_min = retention_config.get("allowed_lifetime_min") + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get("allowed_lifetime_max") + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min > self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs = [] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append({ + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + }) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [{ + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + }] + self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): @@ -761,6 +870,69 @@ class ServerConfig(Config): # Defaults to `28d`. Set to `null` to disable clearing out of old rows. # #user_ips_max_age: 14d + + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h """ % locals() ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 272426e105..9b90c9ce04 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import integer_types, string_types from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID class EventValidator(object): - def validate_new(self, event): + def validate_new(self, event, config): """Validates the event has roughly the right format Args: - event (FrozenEvent) + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. """ self.validate_builder(event) @@ -67,6 +68,99 @@ class EventValidator(object): Codes.INVALID_PARAM, ) + if event.type == EventTypes.Retention: + self._validate_retention(event, config) + + def _validate_retention(self, event, config): + """Checks that an event that defines the retention policy for a room respects the + boundaries imposed by the server's administrator. + + Args: + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. + """ + min_lifetime = event.content.get("min_lifetime") + max_lifetime = event.content.get("max_lifetime") + + if min_lifetime is not None: + if not isinstance(min_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'min_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and min_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be lower than the minimum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and min_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if max_lifetime is not None: + if not isinstance(max_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'max_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and max_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be lower than the minimum allowed value" + " enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and max_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + min_lifetime is not None + and max_lifetime is not None + and min_lifetime > max_lifetime + ): + raise SynapseError( + code=400, + msg="'min_lifetime' can't be greater than 'max_lifetime", + errcode=Codes.BAD_JSON, + ) + def validate_builder(self, event): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8cafcfdab0..3994137d18 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2454,7 +2454,7 @@ class FederationHandler(BaseHandler): room_version, event_dict, event, context ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even # though the sender isn't a local user. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d682dc2b7a..155ed6e06a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -417,7 +417,7 @@ class EventCreationHandler(object): 403, "You must be in the room to create an alias for it" ) - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) return (event, context) @@ -634,7 +634,7 @@ class EventCreationHandler(object): if requester: context.app_service = requester.app_service - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 97f15a1c32..e1800177fa 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,12 +15,15 @@ # limitations under the License. import logging +from six import iteritems + from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock @@ -80,6 +83,114 @@ class PaginationHandler(object): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() + self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + + if hs.config.retention_enabled: + # Run the purge jobs described in the configuration file. + for job in hs.config.retention_purge_jobs: + self.clock.looping_call( + run_as_background_process, + job["interval"], + "purge_history_for_rooms_in_range", + self.purge_history_for_rooms_in_range, + job["shortest_max_lifetime"], + job["longest_max_lifetime"], + ) + + @defer.inlineCallbacks + def purge_history_for_rooms_in_range(self, min_ms, max_ms): + """Purge outdated events from rooms within the given retention range. + + If a default retention policy is defined in the server's configuration and its + 'max_lifetime' is within this range, also targets rooms which don't have a + retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, it means that the range has no + lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, it means that the range has no + upper limit. + """ + # We want the storage layer to to include rooms with no retention policy in its + # return value only if a default retention policy is defined in the server's + # configuration and that policy's 'max_lifetime' is either lower (or equal) than + # max_ms or higher than min_ms (or both). + if self._retention_default_max_lifetime is not None: + include_null = True + + if min_ms is not None and min_ms >= self._retention_default_max_lifetime: + # The default max_lifetime is lower than (or equal to) min_ms. + include_null = False + + if max_ms is not None and max_ms < self._retention_default_max_lifetime: + # The default max_lifetime is higher than max_ms. + include_null = False + else: + include_null = False + + rooms = yield self.store.get_rooms_for_retention_period_in_range( + min_ms, max_ms, include_null + ) + + for room_id, retention_policy in iteritems(rooms): + if room_id in self._purges_in_progress_by_room: + logger.warning( + "[purge] not purging room %s as there's an ongoing purge running" + " for this room", + room_id, + ) + continue + + max_lifetime = retention_policy["max_lifetime"] + + if max_lifetime is None: + # If max_lifetime is None, it means that include_null equals True, + # therefore we can safely assume that there is a default policy defined + # in the server's configuration. + max_lifetime = self._retention_default_max_lifetime + + # Figure out what token we should start purging at. + ts = self.clock.time_msec() - max_lifetime + + stream_ordering = ( + yield self.store.find_first_stream_ordering_after_ts(ts) + ) + + r = ( + yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + ) + if not r: + logger.warning( + "[purge] purging events not possible: No event found " + "(ts %i => stream_ordering %i)", + ts, stream_ordering, + ) + continue + + (stream, topo, _event_id) = r + token = "t%d-%d" % (topo, stream) + + purge_id = random_string(16) + + self._purges_by_id[purge_id] = PurgeStatus() + + logger.info( + "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id) + ) + + # We want to purge everything, including local events, and to run the purge in + # the background so that it's not blocking any other operation apart from + # other purges in the same room. + run_as_background_process( + "_purge_history", + self._purge_history, + purge_id, room_id, token, True, + ) + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 301f8ea128..b332a42d82 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -929,6 +929,9 @@ class EventsStore( elif event.type == EventTypes.Redaction: # Insert into the redactions table. self._store_redaction(txn, event) + elif event.type == EventTypes.Retention: + # Update the room_retention table. + self._store_retention_policy_for_room_txn(txn, event) self._handle_event_relations(txn, event) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 67bb1b6f60..54a7d24c73 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -19,10 +19,13 @@ import logging import re from typing import Optional, Tuple +from six import integer_types + from canonicaljson import json from twisted.internet import defer +from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore @@ -302,6 +305,85 @@ class RoomWorkerStore(SQLBaseStore): class RoomStore(RoomWorkerStore, SearchStore): + def __init__(self, db_conn, hs): + super(RoomStore, self).__init__(db_conn, hs) + + self.config = hs.config + + self.register_background_update_handler( + "insert_room_retention", self._background_insert_retention, + ) + + @defer.inlineCallbacks + def _background_insert_retention(self, progress, batch_size): + """Retrieves a list of all rooms within a range and inserts an entry for each of + them into the room_retention table. + NULLs the property's columns if missing from the retention event in the room's + state (or NULLs all of them if there's no retention event in the room's state), + so that we fall back to the server's retention policy. + """ + + last_room = progress.get("room_id", "") + + def _background_insert_retention_txn(txn): + txn.execute( + """ + SELECT state.room_id, state.event_id, events.json + FROM current_state_events as state + LEFT JOIN event_json AS events ON (state.event_id = events.event_id) + WHERE state.room_id > ? AND state.type = '%s' + ORDER BY state.room_id ASC + LIMIT ?; + """ % EventTypes.Retention, + (last_room, batch_size) + ) + + rows = self.cursor_to_dict(txn) + + if not rows: + return True + + for row in rows: + if not row["json"]: + retention_policy = {} + else: + ev = json.loads(row["json"]) + retention_policy = json.dumps(ev["content"]) + + self._simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": row["room_id"], + "event_id": row["event_id"], + "min_lifetime": retention_policy.get("min_lifetime"), + "max_lifetime": retention_policy.get("max_lifetime"), + } + ) + + logger.info("Inserted %d rows into room_retention", len(rows)) + + self._background_update_progress_txn( + txn, "insert_room_retention", { + "room_id": rows[-1]["room_id"], + } + ) + + if batch_size > len(rows): + return True + else: + return False + + end = yield self.runInteraction( + "insert_room_retention", + _background_insert_retention_txn, + ) + + if end: + yield self._end_background_update("insert_room_retention") + + defer.returnValue(batch_size) + @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -502,6 +584,37 @@ class RoomStore(RoomWorkerStore, SearchStore): txn, event, "content.body", event.content["body"] ) + def _store_retention_policy_for_room_txn(self, txn, event): + if ( + hasattr(event, "content") + and ("min_lifetime" in event.content or "max_lifetime" in event.content) + ): + if ( + ("min_lifetime" in event.content and not isinstance( + event.content.get("min_lifetime"), integer_types + )) + or ("max_lifetime" in event.content and not isinstance( + event.content.get("max_lifetime"), integer_types + )) + ): + # Ignore the event if one of the value isn't an integer. + return + + self._simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) + + self._invalidate_cache_and_stream( + txn, self.get_retention_policy_for_room, (event.room_id,) + ) + def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): @@ -683,3 +796,142 @@ class RoomStore(RoomWorkerStore, SearchStore): remote_media_mxcs.append((hostname, media_id)) return local_media_mxcs, remote_media_mxcs + + @defer.inlineCallbacks + def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False): + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null (bool): Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + dict[str, dict]: The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). + """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + rooms = yield self.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + + defer.returnValue(rooms) + + @cachedInlineCallbacks() + def get_retention_policy_for_room(self, room_id): + """Get the retention policy for a given room. + + If no retention policy has been found for this room, returns a policy defined + by the configured default policy (which has None as both the 'min_lifetime' and + the 'max_lifetime' if no default policy has been defined in the server's + configuration). + + Args: + room_id (str): The ID of the room to get the retention policy of. + + Returns: + dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + """ + + def get_retention_policy_for_room_txn(txn): + txn.execute( + """ + SELECT min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + WHERE room_id = ?; + """, + (room_id,) + ) + + return self.cursor_to_dict(txn) + + ret = yield self.runInteraction( + "get_retention_policy_for_room", + get_retention_policy_for_room_txn, + ) + + # If we don't know this room ID, ret will be None, in this case return the default + # policy. + if not ret: + defer.returnValue({ + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + }) + + row = ret[0] + + # If one of the room's policy's attributes isn't defined, use the matching + # attribute from the default policy. + # The default values will be None if no default policy has been defined, or if one + # of the attributes is missing from the default policy. + if row["min_lifetime"] is None: + row["min_lifetime"] = self.config.retention_default_min_lifetime + + if row["max_lifetime"] is None: + row["max_lifetime"] = self.config.retention_default_max_lifetime + + defer.returnValue(row) diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql new file mode 100644 index 0000000000..ee6cdf7a14 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql @@ -0,0 +1,33 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks the retention policy of a room. +-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in +-- the room's retention policy state event. +-- If a room doesn't have a retention policy state event in its state, both max_lifetime +-- and min_lifetime are NULL. +CREATE TABLE IF NOT EXISTS room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); + +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('insert_room_retention', '{}'); diff --git a/synapse/visibility.py b/synapse/visibility.py index 8c843febd8..4498c156bc 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -86,6 +86,14 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + room_ids = set(e.room_id for e in events) + retention_policies = {} + + for room_id in room_ids: + retention_policies[room_id] = yield storage.main.get_retention_policy_for_room( + room_id + ) + def allowed(event): """ Args: @@ -103,6 +111,15 @@ def filter_events_for_client( if not event.is_state() and event.sender in ignore_list: return None + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") + + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + + if event.origin_server_ts < oldest_allowed_ts: + return None + if event.event_id in always_include_ids: return event diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..41ea9db689 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 4, + }, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_hour_ms, + }, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": lifetime, + }, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success( + store.get_event( + resp.get("event_id") + ) + )) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success( + store.get_event( + valid_event_id + ) + )) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success(filter_events_for_client( + storage, self.user_id, events + )) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["default_room_version"] = "1" + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, + federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={ + "max_lifetime": one_day_ms * 35, + }, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send( + room_id=room_id, + body="1", + tok=self.token, + ) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual(expired_event.get("content", {}).get("body"), "1", expired_event) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send( + room_id=room_id, + body="2", + tok=self.token, + ) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual(first_event.get("content", {}).get("body"), "1", first_event) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body -- cgit 1.5.1 From 708cef88cfbf8dd6df44d2da4ab4dbc7eb584f74 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 28 Nov 2019 19:26:13 +0000 Subject: Discard retention policies when retrieving state Purge jobs don't delete the latest event in a room in order to keep the forward extremity and not break the room. On the other hand, get_state_events, when given an at_token argument calls filter_events_for_client to know if the user can see the event that matches that (sync) token. That function uses the retention policies of the events it's given to filter out those that are too old from a client's view. Some clients, such as Riot, when loading a room, request the list of members for the latest sync token it knows about, and get confused to the point of refusing to send any message if the server tells it that it can't get that information. This can happen very easily with the message retention feature turned on and a room with low activity so that the last event sent becomes too old according to the room's retention policy. An easy and clean fix for that issue is to discard the room's retention policies when retrieving state. --- synapse/handlers/message.py | 2 +- synapse/visibility.py | 22 ++++++++++++++-------- 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 155ed6e06a..3b0156f516 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -138,7 +138,7 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.storage, user_id, last_events + self.storage, user_id, last_events, apply_retention_policies=False ) event = last_events[0] diff --git a/synapse/visibility.py b/synapse/visibility.py index 4d4141dacc..7b037eeb0c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -44,7 +44,8 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset(), + apply_retention_policies=True, ): """ Check which events a user is allowed to see @@ -59,6 +60,10 @@ def filter_events_for_client( events always_include_ids (set(event_id)): set of event ids to specifically include (unless sender is ignored) + apply_retention_policies (bool): Whether to filter out events that's older than + allowed by the room's retention policy. Useful when this function is called + to e.g. check whether a user should be allowed to see the state at a given + event rather than to know if it should send an event to a user's client(s). Returns: Deferred[list[synapse.events.EventBase]] @@ -86,13 +91,14 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) - room_ids = set(e.room_id for e in events) - retention_policies = {} + if apply_retention_policies: + room_ids = set(e.room_id for e in events) + retention_policies = {} - for room_id in room_ids: - retention_policies[room_id] = yield storage.main.get_retention_policy_for_room( - room_id - ) + for room_id in room_ids: + retention_policies[room_id] = ( + yield storage.main.get_retention_policy_for_room(room_id) + ) def allowed(event): """ @@ -113,7 +119,7 @@ def filter_events_for_client( # Don't try to apply the room's retention policy if the event is a state event, as # MSC1763 states that retention is only considered for non-state events. - if not event.is_state(): + if apply_retention_policies and not event.is_state(): retention_policy = retention_policies[event.room_id] max_lifetime = retention_policy.get("max_lifetime") -- cgit 1.5.1 From 54dd5dc12b0ac5c48303144c4a73ce3822209488 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 3 Dec 2019 19:19:45 +0000 Subject: Add ephemeral messages support (MSC2228) (#6409) Implement part [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). The parts that differ are: * the feature is hidden behind a configuration flag (`enable_ephemeral_messages`) * self-destruction doesn't happen for state events * only implement support for the `m.self_destruct_after` field (not the `m.self_destruct` one) * doesn't send synthetic redactions to clients because for this specific case we consider the clients to be able to destroy an event themselves, instead we just censor it (by pruning its JSON) in the database --- changelog.d/6409.feature | 1 + synapse/api/constants.py | 4 + synapse/config/server.py | 2 + synapse/handlers/federation.py | 8 ++ synapse/handlers/message.py | 123 +++++++++++++++++++- synapse/storage/data_stores/main/events.py | 126 ++++++++++++++++++++- .../main/schema/delta/56/event_expiry.sql | 21 ++++ tests/rest/client/test_ephemeral_message.py | 101 +++++++++++++++++ 8 files changed, 379 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6409.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql create mode 100644 tests/rest/client/test_ephemeral_message.py (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature new file mode 100644 index 0000000000..653ff5a5ad --- /dev/null +++ b/changelog.d/6409.feature @@ -0,0 +1 @@ +Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e3f086f1c3..69cef369a5 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -147,3 +147,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/config/server.py b/synapse/config/server.py index 7a9d711669..837fbe1582 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -490,6 +490,8 @@ class ServerConfig(Config): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3267734f7..d9d0cd9eef 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -121,6 +121,7 @@ class FederationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -141,6 +142,8 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -2715,6 +2718,11 @@ class FederationHandler(BaseHandler): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3b0156f516..4f53a5f5dc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,6 +70,17 @@ class MessageHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -225,6 +244,100 @@ class MessageHandler(object): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -295,6 +408,10 @@ class EventCreationHandler(object): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -877,6 +994,10 @@ class EventCreationHandler(object): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2737a1d3ae..79c91fe284 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -130,6 +130,8 @@ class EventsStore( if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -940,6 +942,12 @@ class EventsStore( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1101,12 +1109,7 @@ class EventsStore( def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) self._simple_update_one_txn( txn, @@ -1117,6 +1120,22 @@ class EventsStore( yield self.runInteraction("_update_censor_txn", _update_censor_txn) + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self._simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + @defer.inlineCallbacks def count_daily_messages(self): """ @@ -1957,6 +1976,101 @@ class EventsStore( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self._simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self._simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body -- cgit 1.5.1 From 8ad8bcbed0bbd8a00ddfbe693b99785b72bb8ee2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Dec 2019 11:50:34 +0000 Subject: Pull out room_invite_state_types config option once. Pulling things out of config is currently surprisingly expensive. --- synapse/handlers/message.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4f53a5f5dc..54fa216d83 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -363,6 +363,8 @@ class EventCreationHandler(object): self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self.room_invite_state_types = self.hs.config.room_invite_state_types + self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users @@ -916,7 +918,7 @@ class EventCreationHandler(object): state_to_include_ids = [ e_id for k, e_id in iteritems(current_state_ids) - if k[0] in self.hs.config.room_invite_state_types + if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] -- cgit 1.5.1 From fc316a4894912f49f5d0321e533aabca5624b0ba Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 11 Dec 2019 13:39:47 +0000 Subject: Prevent redacted events from appearing in message search (#6377) --- changelog.d/6377.bugfix | 1 + synapse/handlers/federation.py | 7 +- synapse/handlers/message.py | 5 +- synapse/state/__init__.py | 3 +- synapse/storage/data_stores/main/events_worker.py | 97 ++++++++++++++--------- synapse/storage/data_stores/main/search.py | 8 +- 6 files changed, 78 insertions(+), 43 deletions(-) create mode 100644 changelog.d/6377.bugfix (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6377.bugfix b/changelog.d/6377.bugfix new file mode 100644 index 0000000000..ccda96962f --- /dev/null +++ b/changelog.d/6377.bugfix @@ -0,0 +1 @@ +Prevent redacted events from being returned during message search. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 13865c470c..8f3c9d7702 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -63,6 +63,7 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer @@ -423,7 +424,7 @@ class FederationHandler(BaseHandler): evs = yield self.store.get_events( list(state_map.values()), get_prev_content=False, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, ) event_map.update(evs) @@ -1000,7 +1001,9 @@ class FederationHandler(BaseHandler): forward_events = yield self.store.get_successor_events(list(extremities)) extremities_events = yield self.store.get_events( - forward_events, check_redacted=False, get_prev_content=False + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, ) # We set `check_history_visibility_only` as we might otherwise get false diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54fa216d83..bf9add7fe2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -875,7 +876,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -952,7 +953,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 139beef8ed..3e6d62eef1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -645,7 +646,7 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9ee117ce0f..2c9142814c 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,8 +19,10 @@ import itertools import logging import threading from collections import namedtuple +from typing import List, Optional from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer @@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) @@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + allow_rejected: If True return rejected events. + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True return rejected events. Returns: Deferred : Dict from event_id to event. """ events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True, return rejected events. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 4eec2fae5e..dfb46ee0f8 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} -- cgit 1.5.1 From fa780e9721c940479a72eed9877ccad4fef78160 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Dec 2019 10:32:02 +0000 Subject: Change EventContext to use the Storage class (#6564) --- changelog.d/6564.misc | 1 + synapse/api/auth.py | 2 +- synapse/events/snapshot.py | 36 +++++++++++++++----------- synapse/events/third_party_rules.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/federation.py | 14 +++++----- synapse/handlers/message.py | 10 +++---- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 4 +-- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/replication/http/federation.py | 5 +++- synapse/replication/http/send_event.py | 3 ++- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- tests/test_state.py | 28 ++++++++++---------- 15 files changed, 64 insertions(+), 53 deletions(-) create mode 100644 changelog.d/6564.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6564.misc b/changelog.d/6564.misc new file mode 100644 index 0000000000..f644f5868b --- /dev/null +++ b/changelog.d/6564.misc @@ -0,0 +1 @@ +Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9fd52a8c77..abbc7079a3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -79,7 +79,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 64e898f40c..a44baea365 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -149,7 +149,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_ids = yield self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -167,12 +167,13 @@ class EventContext: } @staticmethod - def deserialize(store, input): + def deserialize(storage, input): """Converts a dict that was produced by `serialize` back into a EventContext. Args: - store (DataStore): Used to convert AS ID to AS object + storage (Storage): Used to convert AS ID to AS object and fetch + state. input (dict): A dict produced by `serialize` Returns: @@ -181,6 +182,7 @@ class EventContext: context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. + storage=storage, prev_state_id=input["prev_state_id"], event_type=input["event_type"], event_state_key=input["event_state_key"], @@ -193,7 +195,7 @@ class EventContext: app_service_id = input["app_service_id"] if app_service_id: - context.app_service = store.get_app_service_by_id(app_service_id) + context.app_service = storage.main.get_app_service_by_id(app_service_id) return context @@ -216,7 +218,7 @@ class EventContext: return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): + def get_current_state_ids(self): """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -234,11 +236,11 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): + def get_prev_state_ids(self): """ Gets the room state map, excluding this event. @@ -250,7 +252,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -270,7 +272,7 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self, store): + def _ensure_fetched(self): return defer.succeed(None) @@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext): Attributes: + _storage (Storage) + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have been calculated. None if we haven't started calculating yet @@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext): that was replaced. """ + # This needs to have a default as we're inheriting + _storage = attr.ib(default=None) _prev_state_id = attr.ib(default=None) _event_type = attr.ib(default=None) _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self, store): + def _ensure_fetched(self): if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) + self._fetching_state_deferred = run_in_background(self._fill_out_state) return make_deferred_yieldable(self._fetching_state_deferred) @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) + self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self.state_group + ) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579..86f7e5f8aa 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -53,7 +53,7 @@ class ThirdPartyEventRules(object): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb..51413d910e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -134,7 +134,7 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state = yield self.store.get_events( list(current_state_ids.values()) ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 60bb00fc6a..05ae40dde7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -718,7 +718,7 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = await context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: @@ -1418,7 +1418,7 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) @@ -1927,7 +1927,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -2336,12 +2336,12 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) @@ -2625,7 +2625,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( @@ -2673,7 +2673,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bf9add7fe2..4ad752205f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -515,7 +515,7 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -665,7 +665,7 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -914,7 +914,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() state_to_include_ids = [ e_id @@ -967,7 +967,7 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -989,7 +989,7 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d3a1a7b4a6..89c9118b26 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler): requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = yield tombstone_context.get_current_state_ids() # update any aliases yield self._move_aliases_to_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7b7270fc61..44c5e3239c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,7 +193,7 @@ class RoomMemberHandler(object): requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -601,7 +601,7 @@ class RoomMemberHandler(object): if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: guest_can_join = yield self._can_guest_join(prev_state_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7881780760..7d9f5a38d9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -304,7 +304,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 9af4e7e173..49a3251372 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_handler = hs.get_handlers().federation_handler @@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = EventContext.deserialize(self.store, event_payload["context"]) + context = EventContext.deserialize( + self.storage, event_payload["context"] + ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9bafd60b14..84b92f16ad 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() @staticmethod @@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.storage, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 5ba13aa973..e2673ae073 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -244,7 +244,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 92e3b9c512..70ff5751b6 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/tests/test_state.py b/tests/test_state.py index 176535947a..e0aae06be4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) @@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_e.get_prev_state_ids() self.assertSetEqual( {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) @@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) @@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(current_state_ids.values()) @@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(prev_state_ids.values()) @@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) -- cgit 1.5.1 From 5a047816434e2ce2df8b80eb63a49c17dc3085fb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 15:31:09 +0000 Subject: rename get_prev_events_for_room to get_prev_events_and_hashes_for_room ... to make way for a new method which just returns the event ids --- synapse/handlers/message.py | 6 ++++-- synapse/handlers/room_member.py | 4 +++- synapse/storage/data_stores/main/event_federation.py | 5 +++-- tests/storage/test_event_federation.py | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4ad752205f..2695975a16 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -740,7 +740,7 @@ class EventCreationHandler(object): % (len(prev_events_and_hashes),) ) else: - prev_events_and_hashes = yield self.store.get_prev_events_for_room( + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( builder.room_id ) @@ -1042,7 +1042,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( + room_id + ) latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 44c5e3239c..91bb34cd55 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -370,7 +370,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id) + prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( + room_id + ) latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) current_state_ids = yield self.state_handler.get_current_state_ids( diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 1f517e8fad..266fc9715f 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -149,9 +149,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) @defer.inlineCallbacks - def get_prev_events_for_room(self, room_id): + def get_prev_events_and_hashes_for_room(self, room_id): """ - Gets a subset of the current forward extremities in the given room. + Gets a subset of the current forward extremities in the given room, + along with their depths and hashes. Limits the result to 10 extremities, so that we can avoid creating events which refer to hundreds of prev_events. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index eadfb90a22..3a68bf3274 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -26,7 +26,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() @defer.inlineCallbacks - def test_get_prev_events_for_room(self): + def test_get_prev_events_and_hashes_for_room(self): room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities @@ -64,7 +64,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others - r = yield self.store.get_prev_events_for_room(room_id) + r = yield self.store.get_prev_events_and_hashes_for_room(room_id) self.assertEqual(10, len(r)) for i in range(0, 5): el = r[i] -- cgit 1.5.1 From 15720092ac7a1af57dde7018a8872d93bbb9d36b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:09:24 +0000 Subject: replace get_prev_events_and_hashes_for_room with get_prev_events_for_room in create_new_client_event --- synapse/handlers/message.py | 12 ++------ .../storage/data_stores/main/event_federation.py | 35 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 9 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2695975a16..a1c289b24a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -739,17 +739,11 @@ class EventCreationHandler(object): "Attempting to create an event with %i prev_events" % (len(prev_events_and_hashes),) ) + prev_event_ids = [event_id for event_id, _, _ in prev_events_and_hashes] else: - prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( - builder.room_id - ) - - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in prev_events_and_hashes - ] + prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) - event = yield builder.build(prev_event_ids=[p for p, _ in prev_events]) + event = yield builder.build(prev_event_ids=prev_event_ids) context = yield self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 266fc9715f..88e6489576 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -177,6 +177,41 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return res + def get_prev_events_for_room(self, room_id: str): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[List[str]]: the event ids of the forward extremites + + """ + + return self.db.runInteraction( + "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id + ) + + def _get_prev_events_for_room_txn(self, txn, room_id: str): + # we just use the 10 newest events. Older events will become + # prev_events of future events. + + sql = """ + SELECT e.event_id FROM event_forward_extremities AS f + INNER JOIN events AS e USING (event_id) + WHERE f.room_id = ? + ORDER BY e.depth DESC + LIMIT 10 + """ + + txn.execute(sql, (room_id,)) + + return [row[0] for row in txn] + def get_latest_event_ids_and_hashes_in_room(self, room_id): """ Gets the current forward extremities in the given room -- cgit 1.5.1 From 66ca914dc0290b16516cbb599dc4be06793963ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:16:09 +0000 Subject: Remove unused hashes and depths from create_new_client_event params --- synapse/handlers/message.py | 26 ++++++++++++++------------ synapse/types.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a1c289b24a..5415b0c9ee 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -48,7 +48,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, UserID, create_requester +from synapse.types import Collection, RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -497,10 +497,14 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id + prev_event_ids = ( + None + if prev_events_and_hashes is None + else [event_id for event_id, _, _ in prev_events_and_hashes] + ) + event, context = yield self.create_new_client_event( - builder=builder, - requester=requester, - prev_events_and_hashes=prev_events_and_hashes, + builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -714,7 +718,7 @@ class EventCreationHandler(object): @measure_func("create_new_client_event") @defer.inlineCallbacks def create_new_client_event( - self, builder, requester=None, prev_events_and_hashes=None + self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None ): """Create a new event for a local client @@ -723,10 +727,9 @@ class EventCreationHandler(object): requester (synapse.types.Requester|None): - prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + prev_event_ids: the forward extremities to use as the prev_events for the - new event. For each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + new event. If None, they will be requested from the database. @@ -734,12 +737,11 @@ class EventCreationHandler(object): Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] """ - if prev_events_and_hashes is not None: - assert len(prev_events_and_hashes) <= 10, ( + if prev_event_ids is not None: + assert len(prev_event_ids) <= 10, ( "Attempting to create an event with %i prev_events" - % (len(prev_events_and_hashes),) + % (len(prev_event_ids),) ) - prev_event_ids = [event_id for event_id, _, _ in prev_events_and_hashes] else: prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) diff --git a/synapse/types.py b/synapse/types.py index aafc3ffe74..cd996c0b5a 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -15,6 +15,7 @@ # limitations under the License. import re import string +import sys from collections import namedtuple import attr @@ -23,6 +24,17 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import SynapseError +# define a version of typing.Collection that works on python 3.5 +if sys.version_info[:3] >= (3, 6, 0): + from typing import Collection +else: + from typing import Sized, Iterable, Container, TypeVar + + T_co = TypeVar("T_co", covariant=True) + + class Collection(Iterable[T_co], Container[T_co], Sized): + __slots__ = () + class Requester( namedtuple( -- cgit 1.5.1 From 3bef62488e5cff4dfb33454f2f2e18cc928f319b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 3 Jan 2020 16:19:55 +0000 Subject: Remove unused hashes and depths from create_event params --- synapse/handlers/message.py | 21 +++++---------------- synapse/handlers/room_member.py | 8 +++++++- tests/unittest.py | 6 +----- 3 files changed, 13 insertions(+), 22 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5415b0c9ee..8ea3aca2f4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -422,7 +422,7 @@ class EventCreationHandler(object): event_dict, token_id=None, txn_id=None, - prev_events_and_hashes=None, + prev_event_ids: Optional[Collection[str]] = None, require_consent=True, ): """ @@ -439,10 +439,9 @@ class EventCreationHandler(object): token_id (str) txn_id (str) - prev_events_and_hashes (list[(str, dict[str, str], int)]|None): + prev_event_ids: the forward extremities to use as the prev_events for the - new event. For each event, a tuple of (event_id, hashes, depth) - where *hashes* is a map from algorithm to hash. + new event. If None, they will be requested from the database. @@ -497,12 +496,6 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - prev_event_ids = ( - None - if prev_events_and_hashes is None - else [event_id for event_id, _, _ in prev_events_and_hashes] - ) - event, context = yield self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -1038,11 +1031,7 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - prev_events_and_hashes = yield self.store.get_prev_events_and_hashes_for_room( - room_id - ) - - latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes) + latest_event_ids = yield self.store.get_prev_events_for_room(room_id) members = yield self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids @@ -1061,7 +1050,7 @@ class EventCreationHandler(object): "room_id": room_id, "sender": user_id, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=latest_event_ids, ) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 91bb34cd55..d550ba8ab4 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -164,6 +164,12 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" + prev_event_ids = ( + None + if prev_events_and_hashes is None + else [event_id for event_id, _, _ in prev_events_and_hashes] + ) + event, context = yield self.event_creation_handler.create_event( requester, { @@ -177,7 +183,7 @@ class RoomMemberHandler(object): }, token_id=requester.access_token_id, txn_id=txn_id, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, require_consent=require_consent, ) diff --git a/tests/unittest.py b/tests/unittest.py index b30b7d1718..07b50c0ccd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -522,10 +522,6 @@ class HomeserverTestCase(TestCase): secrets = self.hs.get_secrets() requester = Requester(user, None, False, None, None) - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - event, context = self.get_success( event_creator.create_event( requester, @@ -535,7 +531,7 @@ class HomeserverTestCase(TestCase): "sender": user.to_string(), "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, }, - prev_events_and_hashes=prev_events_and_hashes, + prev_event_ids=prev_event_ids, ) ) -- cgit 1.5.1 From a8ce7aeb433e08f46306797a1252668c178a7825 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Jan 2020 14:18:29 +0000 Subject: Pass room version object into event_auth.check and check_redaction (#6788) These are easier to work with than the strings and we normally have one around. This fixes `FederationHander._persist_auth_tree` which was passing a RoomVersion object into event_auth.check instead of a string. --- changelog.d/6788.misc | 1 + synapse/api/auth.py | 7 +++++-- synapse/event_auth.py | 34 +++++++++++++++++++++------------- synapse/handlers/federation.py | 18 +++++++++++------- synapse/handlers/message.py | 8 ++++++-- synapse/state/v1.py | 4 ++-- synapse/state/v2.py | 4 +++- tests/test_event_auth.py | 11 ++++------- 8 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 changelog.d/6788.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6788.misc b/changelog.d/6788.misc new file mode 100644 index 0000000000..5537355bea --- /dev/null +++ b/changelog.d/6788.misc @@ -0,0 +1 @@ +Record room versions in the `rooms` table. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2cbfab2569..8b1277ad02 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import is_threepid_reserved from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache @@ -77,15 +78,17 @@ class Auth(object): self._account_validity = hs.config.account_validity @defer.inlineCallbacks - def check_from_context(self, room_version, event, context, do_sig_check=True): + def check_from_context(self, room_version: str, event, context, do_sig_check=True): prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} + + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] event_auth.check( - room_version, event, auth_events=auth_events, do_sig_check=do_sig_check + room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) @defer.inlineCallbacks diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e3a1ba47a0..016d5678e5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,17 +24,27 @@ from unpaddedbase64 import decode_base64 from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersion, +) from synapse.types import UserID, get_domain_from_id logger = logging.getLogger(__name__) -def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True): +def check( + room_version_obj: RoomVersion, + event, + auth_events, + do_sig_check=True, + do_size_check=True, +): """ Checks if this event is correctly authed. Args: - room_version (str): the version of the room + room_version_obj: the version of the room event: the event being checked. auth_events (dict: event-key -> event): the existing room state. @@ -97,10 +108,11 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru 403, "Creation event's room_id domain does not match sender's" ) - room_version = event.content.get("room_version", "1") - if room_version not in KNOWN_ROOM_VERSIONS: + room_version_prop = event.content.get("room_version", "1") + if room_version_prop not in KNOWN_ROOM_VERSIONS: raise AuthError( - 403, "room appears to have unsupported version %s" % (room_version,) + 403, + "room appears to have unsupported version %s" % (room_version_prop,), ) # FIXME logger.debug("Allowing! %s", event) @@ -160,7 +172,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru _check_power_levels(event, auth_events) if event.type == EventTypes.Redaction: - check_redaction(room_version, event, auth_events) + check_redaction(room_version_obj, event, auth_events) logger.debug("Allowing! %s", event) @@ -386,7 +398,7 @@ def _can_send_event(event, auth_events): return True -def check_redaction(room_version, event, auth_events): +def check_redaction(room_version_obj: RoomVersion, event, auth_events): """Check whether the event sender is allowed to redact the target event. Returns: @@ -406,11 +418,7 @@ def check_redaction(room_version, event, auth_events): if user_level >= redact_level: return False - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - raise RuntimeError("Unrecognized room version %r" % (room_version,)) - - if v.event_format == EventFormatVersions.V1: + if room_version_obj.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f824ee79a0..180f165a7a 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -47,7 +47,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event -from synapse.events import EventBase, room_version_to_event_format +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( @@ -1198,7 +1198,7 @@ class FederationHandler(BaseHandler): """ logger.debug("Joining %s to %s", joinee, room_id) - origin, event, room_version = yield self._make_and_verify_event( + origin, event, room_version_obj = yield self._make_and_verify_event( target_hosts, room_id, joinee, @@ -1227,7 +1227,7 @@ class FederationHandler(BaseHandler): except ValueError: pass - event_format_version = room_version_to_event_format(room_version.identifier) + event_format_version = room_version_obj.event_format ret = yield self.federation_client.send_join( target_hosts, event, event_format_version ) @@ -1251,14 +1251,14 @@ class FederationHandler(BaseHandler): room_id=room_id, room_creator_user_id="", is_public=False, - room_version=room_version, + room_version=room_version_obj, ) except Exception: # FIXME pass yield self._persist_auth_tree( - origin, auth_chain, state, event, room_version + origin, auth_chain, state, event, room_version_obj ) # Check whether this room is the result of an upgrade of a room we already know @@ -2022,6 +2022,7 @@ class FederationHandler(BaseHandler): if do_soft_fail_check: room_version = yield self.store.get_room_version(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". if state is not None: @@ -2071,7 +2072,9 @@ class FederationHandler(BaseHandler): } try: - event_auth.check(room_version, event, auth_events=current_auth_events) + event_auth.check( + room_version_obj, event, auth_events=current_auth_events + ) except AuthError as e: logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @@ -2155,6 +2158,7 @@ class FederationHandler(BaseHandler): defer.Deferred[EventContext]: updated context object """ room_version = yield self.store.get_room_version(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] try: context = yield self._update_auth_events_and_context_for_auth( @@ -2172,7 +2176,7 @@ class FederationHandler(BaseHandler): ) try: - event_auth.check(room_version, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8ea3aca2f4..9a0f661b9b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -40,7 +40,7 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background @@ -962,9 +962,13 @@ class EventCreationHandler(object): ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + room_version = yield self.store.get_room_version(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if event_auth.check_redaction(room_version, event, auth_events=auth_events): + if event_auth.check_redaction( + room_version_obj, event, auth_events=auth_events + ): # this user doesn't have 'redact' rights, so we need to do some more # checks on the original event. Let's start by checking the original # event exists. diff --git a/synapse/state/v1.py b/synapse/state/v1.py index d6c34ce3b7..24b7c0faef 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -281,7 +281,7 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, @@ -299,7 +299,7 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, event, auth_events, do_sig_check=False, diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 6216fdd204..531018c6a5 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -26,6 +26,7 @@ import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap @@ -402,6 +403,7 @@ def _iterative_auth_checks( Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for event_id in event_ids: event = event_map[event_id] @@ -430,7 +432,7 @@ def _iterative_auth_checks( try: event_auth.check( - room_version, + room_version_obj, event, auth_events, do_sig_check=False, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 8b2741d277..ca20b085a2 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,7 +37,7 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(creator), auth_events, do_sig_check=False, @@ -47,7 +47,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(joiner), auth_events, do_sig_check=False, @@ -76,7 +76,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1.identifier, + RoomVersions.V1, _random_state_event(pleb), auth_events, do_sig_check=False, @@ -84,10 +84,7 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, - _random_state_event(king), - auth_events, - do_sig_check=False, + RoomVersions.V1, _random_state_event(king), auth_events, do_sig_check=False, ) -- cgit 1.5.1 From d7bf793cc1e3f5268285286341835ac54753eff6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 31 Jan 2020 10:06:21 +0000 Subject: s/get_room_version/get_room_version_id/ ... to make way for a forthcoming get_room_version which returns a RoomVersion object. --- synapse/federation/federation_client.py | 10 +++++----- synapse/federation/federation_server.py | 16 ++++++++-------- synapse/handlers/federation.py | 18 +++++++++--------- synapse/handlers/message.py | 8 +++++--- synapse/handlers/pagination.py | 2 +- synapse/handlers/room.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/data_stores/main/state.py | 2 +- synapse/storage/persist_events.py | 2 +- tests/handlers/test_presence.py | 2 +- tests/test_state.py | 2 +- tests/unittest.py | 4 +++- 12 files changed, 37 insertions(+), 33 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d57e8ca7a2..4ac3d81cba 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -198,7 +198,7 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ @@ -336,7 +336,7 @@ class FederationClient(FederationBase): def get_event_auth(self, destination, room_id, event_id): res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ @@ -649,7 +649,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) content = yield self._do_send_invite(destination, pdu, room_version) @@ -657,7 +657,7 @@ class FederationClient(FederationBase): logger.debug("Got response to send_invite: %s", pdu_dict) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(pdu_dict, format_ver) @@ -859,7 +859,7 @@ class FederationClient(FederationBase): timeout=timeout, ) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) events = [ diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9562faa3ee..a4c97ed458 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -234,7 +234,7 @@ class FederationServer(FederationBase): continue try: - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -334,7 +334,7 @@ class FederationServer(FederationBase): ) ) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) resp["room_version"] = room_version return 200, resp @@ -385,7 +385,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) if room_version not in supported_versions: logger.warning( "Room version %s not in %s", room_version, supported_versions @@ -417,7 +417,7 @@ class FederationServer(FederationBase): async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) @@ -440,7 +440,7 @@ class FederationServer(FederationBase): await self.check_server_matches_acl(origin_host, room_id) pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} @@ -448,7 +448,7 @@ class FederationServer(FederationBase): async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) @@ -495,7 +495,7 @@ class FederationServer(FederationBase): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ @@ -664,7 +664,7 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = await self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version_id(pdu.room_id) # Check signature. try: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 01372f6d47..30c720f093 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -388,7 +388,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) state_map = await resolve_events_with_store( room_id, room_version, @@ -1110,7 +1110,7 @@ class FederationHandler(BaseHandler): Logs a warning if we can't find the given event. """ - room_version = await self.store.get_room_version(room_id) + room_version = await self.store.get_room_version_id(room_id) event_infos = [] @@ -1373,7 +1373,7 @@ class FederationHandler(BaseHandler): event_content = {"membership": Membership.JOIN} - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new( room_version, @@ -1607,7 +1607,7 @@ class FederationHandler(BaseHandler): ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new( room_version, { @@ -2055,7 +2055,7 @@ class FederationHandler(BaseHandler): do_soft_fail_check = False if do_soft_fail_check: - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". @@ -2191,7 +2191,7 @@ class FederationHandler(BaseHandler): Returns: defer.Deferred[EventContext]: updated context object """ - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] try: @@ -2363,7 +2363,7 @@ class FederationHandler(BaseHandler): remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) remote_state = remote_auth_events.values() - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) new_state = yield self.state_handler.resolve_events( room_version, (local_state, remote_state), event ) @@ -2587,7 +2587,7 @@ class FederationHandler(BaseHandler): } if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) @@ -2650,7 +2650,7 @@ class FederationHandler(BaseHandler): Returns: Deferred: resolves (to None) """ - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) # NB: event_dict has a particular specced format we might need to fudge # if we change event formats too much. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9a0f661b9b..bdf16c84d3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -459,7 +459,9 @@ class EventCreationHandler(object): room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version(event_dict["room_id"]) + room_version = yield self.store.get_room_version_id( + event_dict["room_id"] + ) except NotFoundError: raise AuthError(403, "Unknown room") @@ -788,7 +790,7 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) event_allowed = yield self.third_party_event_rules.check_event_allowed( event, context @@ -963,7 +965,7 @@ class EventCreationHandler(object): auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version(event.room_id) + room_version = yield self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if event_auth.check_redaction( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 71d76202c9..caf841a643 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -281,7 +281,7 @@ class PaginationHandler(object): """Purge the given room from the database""" with (await self.pagination_lock.write(room_id)): # check we know about the room - await self.store.get_room_version(room_id) + await self.store.get_room_version_id(room_id) # first check that we have no users in this room joined = await defer.maybeDeferred( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a95b45d791..1382399557 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -178,7 +178,7 @@ class RoomCreationHandler(BaseHandler): }, token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version(old_room_id) + old_room_version = yield self.store.get_room_version_id(old_room_id) yield self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index cacd0c0c2b..fdd6bef6b4 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -394,7 +394,7 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version(room_id) + room_version = yield self.store.get_room_version_id(room_id) result = yield self._state_resolution_handler.resolve_state_groups( room_id, diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 4167f83c9b..6700942523 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -62,7 +62,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) @cached(max_entries=10000) - async def get_room_version(self, room_id: str) -> str: + async def get_room_version_id(self, room_id: str) -> str: """Get the room_version of a given room Raises: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 86166fd4c1..af3fd67ab9 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -661,7 +661,7 @@ class EventsPersistenceStorage(object): break if not room_version: - room_version = await self.main_store.get_room_version(room_id) + room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") res = await self._state_resolution_handler.resolve_state_groups( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index d4293b4312..e92e090c3c 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -588,7 +588,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version(room_id)) + room_version = self.get_success(self.store.get_room_version_id(room_id)) builder = EventBuilder( state=self.state, diff --git a/tests/test_state.py b/tests/test_state.py index e0aae06be4..1e4449fa1c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -119,7 +119,7 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group - def get_room_version(self, room_id): + def get_room_version_id(self, room_id): return RoomVersions.V1.identifier diff --git a/tests/unittest.py b/tests/unittest.py index b56e249386..98bf27d39c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -589,7 +589,9 @@ class HomeserverTestCase(TestCase): event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - room_version = self.get_success(self.hs.get_datastore().get_room_version(room)) + room_version = self.get_success( + self.hs.get_datastore().get_room_version_id(room) + ) builder = event_builder_factory.for_room_version( KNOWN_ROOM_VERSIONS[room_version], -- cgit 1.5.1 From 5d17c3159618b6e75bb58ba68a77a73572a85688 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 3 Feb 2020 22:28:11 +0000 Subject: make FederationHandler.send_invite async --- synapse/handlers/federation.py | 5 ++--- synapse/handlers/message.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ea2f6a91d7..5728ea2ee7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1184,13 +1184,12 @@ class FederationHandler(BaseHandler): ) raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") - @defer.inlineCallbacks - def send_invite(self, target_host, event): + async def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. """ - pdu = yield self.federation_client.send_invite( + pdu = await self.federation_client.send_invite( destination=target_host, room_id=event.room_id, event_id=event.event_id, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bdf16c84d3..be6ae18a92 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -932,10 +932,9 @@ class EventCreationHandler(object): # way? If we have been invited by a remote server, we need # to get them to sign the event. - returned_invite = yield federation_handler.send_invite( - invitee.domain, event + returned_invite = yield defer.ensureDeferred( + federation_handler.send_invite(invitee.domain, event) ) - event.unsigned.pop("room_state", None) # TODO: Make sure the signatures actually are correct. -- cgit 1.5.1 From a0a1fd0bec5cb596cc41c8f052a4aa0e8c01cf08 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 18 Feb 2020 23:14:57 +0000 Subject: Add `allow_departed_users` param to `check_in_room_or_world_readable` ... and set it everywhere it's called. while we're here, rename it for consistency with `check_user_in_room` (and to help check that I haven't missed any instances) --- synapse/api/auth.py | 16 +++++++++++++--- synapse/handlers/initial_sync.py | 4 +++- synapse/handlers/message.py | 12 ++++++++---- synapse/handlers/pagination.py | 4 +++- synapse/rest/client/v2_alpha/relations.py | 12 ++++++------ 5 files changed, 33 insertions(+), 15 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index de7b75ca36..f576d65388 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -625,10 +625,18 @@ class Auth(object): return query_params[0].decode("ascii") @defer.inlineCallbacks - def check_in_room_or_world_readable(self, room_id, user_id): + def check_user_in_room_or_world_readable( + self, room_id: str, user_id: str, allow_departed_users: bool = False + ): """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. + Args: + room_id: room to check + user_id: user to check + allow_departed_users: if True, accept users that were previously + members but have now departed + Returns: Deferred[tuple[str, str|None]]: Resolves to the current membership of the user in the room and the membership event ID of the user. If @@ -643,7 +651,7 @@ class Auth(object): # * The user is a guest user, and has joined the room # else it will throw. member_event = yield self.check_user_in_room( - room_id, user_id, allow_departed_users=True + room_id, user_id, allow_departed_users=allow_departed_users ) return member_event.membership, member_event.event_id except AuthError: @@ -656,7 +664,9 @@ class Auth(object): ): return Membership.JOIN, None raise AuthError( - 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + 403, + "User %s not in room %s, and room previews are disabled" + % (user_id, room_id), ) @defer.inlineCallbacks diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index b7c6a921d9..b116500c7d 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -277,7 +277,9 @@ class InitialSyncHandler(BaseHandler): ( membership, member_event_id, - ) = await self.auth.check_user_in_room_or_world_readable(room_id, user_id) + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True, + ) is_peeking = member_event_id is None if membership == Membership.JOIN: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index be6ae18a92..d6be280952 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -99,7 +99,9 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = yield self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True + ) if membership == Membership.JOIN: data = yield self.state.get_current_state(room_id, event_type, state_key) @@ -177,7 +179,9 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = yield self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True + ) if membership == Membership.JOIN: state_ids = yield self.store.get_filtered_current_state_ids( @@ -216,8 +220,8 @@ class MessageHandler(object): if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. - membership, _ = yield self.auth.check_in_room_or_world_readable( - room_id, user_id + membership, _ = yield self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True ) if membership != Membership.JOIN: raise NotImplementedError( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index caf841a643..254a9f6856 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -335,7 +335,9 @@ class PaginationHandler(object): ( membership, member_event_id, - ) = await self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = await self.auth.check_user_in_room_or_world_readable( + room_id, user_id, allow_departed_users=True + ) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 9be9a34b91..63f07b63da 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -142,8 +142,8 @@ class RelationPaginationServlet(RestServlet): ): requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True ) # This gets the original event and checks that a) the event exists and @@ -235,8 +235,8 @@ class RelationAggregationPaginationServlet(RestServlet): ): requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to @@ -313,8 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet): async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): requester = await self.auth.get_user_by_req(request, allow_guest=True) - await self.auth.check_in_room_or_world_readable( - room_id, requester.user.to_string() + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to -- cgit 1.5.1 From 1f773eec912e4908ab60f7823f5c0a024261af4d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Feb 2020 15:33:26 +0000 Subject: Port PresenceHandler to async/await (#6991) --- changelog.d/6991.misc | 1 + synapse/handlers/message.py | 5 +- synapse/handlers/presence.py | 192 ++++++++++++++++-------------------- synapse/replication/tcp/resource.py | 6 +- synapse/server.pyi | 5 + tests/handlers/test_presence.py | 18 ++-- tox.ini | 1 + 7 files changed, 113 insertions(+), 115 deletions(-) create mode 100644 changelog.d/6991.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc new file mode 100644 index 0000000000..5130f4e8af --- /dev/null +++ b/changelog.d/6991.misc @@ -0,0 +1 @@ +Port `synapse.handlers.presence` to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be280952..a0103addd3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1016,11 +1016,10 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - @defer.inlineCallbacks - def _bump_active_time(self, user): + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() - yield presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0d6cf2b008..5526015ddb 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,11 +24,12 @@ The methods that define policy are: import logging from contextlib import contextmanager -from typing import Dict, Set +from typing import Dict, List, Set from six import iteritems, itervalues from prometheus_client import Counter +from typing_extensions import ContextManager from twisted.internet import defer @@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs - self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.clock = hs.get_clock() @@ -150,7 +154,7 @@ class PresenceHandler(object): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() + self.unpersisted_users_changes = set() # type: Set[str] hs.get_reactor().addSystemEventTrigger( "before", @@ -160,12 +164,11 @@ class PresenceHandler(object): self._on_shutdown, ) - self.serial_to_user = {} self._next_serial = 1 # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} + self.user_to_num_current_syncs = {} # type: Dict[str, int] # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -213,8 +216,7 @@ class PresenceHandler(object): self._event_pos = self.store.get_current_events_token() self._event_processing = False - @defer.inlineCallbacks - def _on_shutdown(self): + async def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -235,7 +237,7 @@ class PresenceHandler(object): if self.unpersisted_users_changes: - yield self.store.update_presence( + await self.store.update_presence( [ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -243,8 +245,7 @@ class PresenceHandler(object): ) logger.info("Finished _on_shutdown") - @defer.inlineCallbacks - def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -253,12 +254,11 @@ class PresenceHandler(object): if unpersisted: logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) - yield self.store.update_presence( + await self.store.update_presence( [self.user_to_current_state[user_id] for user_id in unpersisted] ) - @defer.inlineCallbacks - def _update_states(self, new_states): + async def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. @@ -267,7 +267,7 @@ class PresenceHandler(object): with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't yield between now and when we've + # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. to_notify = {} # Changes we want to notify everyone about @@ -311,7 +311,7 @@ class PresenceHandler(object): if to_notify: notified_presence_counter.inc(len(to_notify)) - yield self._persist_and_notify(list(to_notify.values())) + await self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) @@ -326,7 +326,7 @@ class PresenceHandler(object): self._push_to_remotes(to_federation_ping.values()) - def _handle_timeouts(self): + async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as appropriate. """ @@ -368,10 +368,9 @@ class PresenceHandler(object): now=now, ) - return self._update_states(changes) + return await self._update_states(changes) - @defer.inlineCallbacks - def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user): """We've seen the user do something that indicates they're interacting with the app. """ @@ -383,16 +382,17 @@ class PresenceHandler(object): bump_active_time_counter.inc() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def user_syncing(self, user_id, affect_presence=True): + async def user_syncing( + self, user_id: str, affect_presence: bool = True + ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -415,11 +415,11 @@ class PresenceHandler(object): curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( state=PresenceState.ONLINE, @@ -429,7 +429,7 @@ class PresenceHandler(object): ] ) else: - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -437,13 +437,12 @@ class PresenceHandler(object): ] ) - @defer.inlineCallbacks - def _end(): + async def _end(): try: self.user_to_num_current_syncs[user_id] -= 1 - prev_state = yield self.current_state_for_user(user_id) - yield self._update_states( + prev_state = await self.current_state_for_user(user_id) + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -480,8 +479,7 @@ class PresenceHandler(object): else: return set() - @defer.inlineCallbacks - def update_external_syncs_row( + async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec ): """Update the syncing users for an external process as a delta. @@ -494,8 +492,8 @@ class PresenceHandler(object): is_syncing (bool): Whether or not the user is now syncing sync_time_msec(int): Time in ms when the user was last syncing """ - with (yield self.external_sync_linearizer.queue(process_id)): - prev_state = yield self.current_state_for_user(user_id) + with (await self.external_sync_linearizer.queue(process_id)): + prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( process_id, set() @@ -525,25 +523,24 @@ class PresenceHandler(object): process_presence.discard(user_id) if updates: - yield self._update_states(updates) + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - @defer.inlineCallbacks - def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id): """Marks all users that had been marked as syncing by a given process as offline. Used when the process has stopped/disappeared. """ - with (yield self.external_sync_linearizer.queue(process_id)): + with (await self.external_sync_linearizer.queue(process_id)): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = yield self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) for prev_state in itervalues(prev_states) @@ -551,15 +548,13 @@ class PresenceHandler(object): ) self.external_process_last_updated_ms.pop(process_id, None) - @defer.inlineCallbacks - def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id): """Get the current presence state for a user. """ - res = yield self.current_state_for_users([user_id]) + res = await self.current_state_for_users([user_id]) return res[user_id] - @defer.inlineCallbacks - def current_state_for_users(self, user_ids): + async def current_state_for_users(self, user_ids): """Get the current presence state for multiple users. Returns: @@ -574,7 +569,7 @@ class PresenceHandler(object): if missing: # There are things not in our in memory cache. Lets pull them out of # the database. - res = yield self.store.get_presence_for_users(missing) + res = await self.store.get_presence_for_users(missing) states.update(res) missing = [user_id for user_id, state in iteritems(states) if not state] @@ -587,14 +582,13 @@ class PresenceHandler(object): return states - @defer.inlineCallbacks - def _persist_and_notify(self, states): + async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ - stream_id, max_token = yield self.store.update_presence(states) + stream_id, max_token = await self.store.update_presence(states) - parties = yield get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -606,9 +600,8 @@ class PresenceHandler(object): self._push_to_remotes(states) - @defer.inlineCallbacks - def notify_for_states(self, state, stream_id): - parties = yield get_interested_parties(self.store, [state]) + async def notify_for_states(self, state, stream_id): + parties = await get_interested_parties(self.store, [state]) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -626,8 +619,7 @@ class PresenceHandler(object): """ self.federation.send_presence(states) - @defer.inlineCallbacks - def incoming_presence(self, origin, content): + async def incoming_presence(self, origin, content): """Called when we receive a `m.presence` EDU from a remote server. """ now = self.clock.time_msec() @@ -670,21 +662,19 @@ class PresenceHandler(object): new_fields["status_msg"] = push.get("status_msg", None) new_fields["currently_active"] = push.get("currently_active", False) - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) updates.append(prev_state.copy_and_replace(**new_fields)) if updates: federation_presence_counter.inc(len(updates)) - yield self._update_states(updates) + await self._update_states(updates) - @defer.inlineCallbacks - def get_state(self, target_user, as_event=False): - results = yield self.get_states([target_user.to_string()], as_event=as_event) + async def get_state(self, target_user, as_event=False): + results = await self.get_states([target_user.to_string()], as_event=as_event) return results[0] - @defer.inlineCallbacks - def get_states(self, target_user_ids, as_event=False): + async def get_states(self, target_user_ids, as_event=False): """Get the presence state for users. Args: @@ -695,7 +685,7 @@ class PresenceHandler(object): list """ - updates = yield self.current_state_for_users(target_user_ids) + updates = await self.current_state_for_users(target_user_ids) updates = list(updates.values()) for user_id in set(target_user_ids) - {u.user_id for u in updates}: @@ -713,8 +703,7 @@ class PresenceHandler(object): else: return updates - @defer.inlineCallbacks - def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -730,7 +719,7 @@ class PresenceHandler(object): user_id = target_user.to_string() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"state": presence} @@ -741,16 +730,15 @@ class PresenceHandler(object): if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_room_ids = yield self.store.get_rooms_for_user( + observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) - observed_room_ids = yield self.store.get_rooms_for_user( + observed_room_ids = await self.store.get_rooms_for_user( observed_user.to_string() ) @@ -759,8 +747,7 @@ class PresenceHandler(object): return False - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -775,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id) return rows def notify_new_event(self): @@ -786,20 +773,18 @@ class PresenceHandler(object): if self._event_processing: return - @defer.inlineCallbacks - def _process_presence(): + async def _process_presence(): assert not self._event_processing self._event_processing = True try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._event_processing = False run_as_background_process("presence.notify_new_event", _process_presence) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -812,10 +797,10 @@ class PresenceHandler(object): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - yield self._handle_state_delta(deltas) + await self._handle_state_delta(deltas) self._event_pos = max_pos @@ -824,8 +809,7 @@ class PresenceHandler(object): max_pos ) - @defer.inlineCallbacks - def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ @@ -846,13 +830,13 @@ class PresenceHandler(object): # joins. continue - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if ( prev_event and prev_event.content.get("membership") == Membership.JOIN @@ -860,10 +844,9 @@ class PresenceHandler(object): # Ignore changes to join events. continue - yield self._on_user_joined_room(room_id, state_key) + await self._on_user_joined_room(room_id, state_key) - @defer.inlineCallbacks - def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id, user_id): """Called when we detect a user joining the room via the current state delta stream. @@ -882,8 +865,8 @@ class PresenceHandler(object): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = yield self.current_state_for_user(user_id) - hosts = yield self.state.get_current_hosts_in_room(room_id) + state = await self.current_state_for_user(user_id) + hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. hosts = {host for host in hosts if host != self.server_name} @@ -903,10 +886,10 @@ class PresenceHandler(object): # TODO: Check that this is actually a new server joining the # room. - user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = yield self.current_state_for_users(user_ids) + states = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -996,9 +979,8 @@ class PresenceEventSource(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - @defer.inlineCallbacks @log_function - def get_new_events( + async def get_new_events( self, user, from_key, @@ -1045,7 +1027,7 @@ class PresenceEventSource(object): presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - users_interested_in = yield self._get_interested_in(user, explicit_room_id) + users_interested_in = await self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1071,7 +1053,7 @@ class PresenceEventSource(object): else: user_ids_changed = users_interested_in - updates = yield presence.current_state_for_users(user_ids_changed) + updates = await presence.current_state_for_users(user_ids_changed) if include_offline: return (list(updates.values()), max_token) @@ -1084,11 +1066,11 @@ class PresenceEventSource(object): def get_current_key(self): return self.store.get_current_presence_token() - def get_pagination_rows(self, user, pagination_config, key): - return self.get_new_events(user, from_key=None, include_offline=False) + async def get_pagination_rows(self, user, pagination_config, key): + return await self.get_new_events(user, from_key=None, include_offline=False) - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _get_interested_in(self, user, explicit_room_id, cache_context): + @cached(num_args=2, cache_context=True) + async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence updates for """ @@ -1096,13 +1078,13 @@ class PresenceEventSource(object): users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: - user_ids = yield self.store.get_users_in_room( + user_ids = await self.store.get_users_in_room( explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1277,8 +1259,8 @@ def get_interested_parties(store, states): 2-tuple: `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} - users_to_states = {} + room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] + users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: room_ids = yield store.get_rooms_for_user(state.user_id) for room_id in room_ids: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce60ae2e07..ce9d1fae12 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -323,7 +323,11 @@ class ReplicationStreamer(object): # We need to tell the presence handler that the connection has been # lost so that it can handle any ongoing syncs on that connection. - self.presence_handler.update_external_syncs_clear(connection.conn_id) + run_as_background_process( + "update_external_syncs_clear", + self.presence_handler.update_external_syncs_clear, + connection.conn_id, + ) def _batch_updates(updates): diff --git a/synapse/server.pyi b/synapse/server.pyi index 40eabfe5d9..3844f0e12f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -3,6 +3,7 @@ import twisted.internet import synapse.api.auth import synapse.config.homeserver import synapse.crypto.keyring +import synapse.federation.federation_server import synapse.federation.sender import synapse.federation.transport.client import synapse.handlers @@ -107,5 +108,9 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def get_federation_registry( + self, + ) -> synapse.federation.federation_server.FederationHandlerRegistry: + pass def is_mine_id(self, domain_id: str) -> bool: pass diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 64915bafcd..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room diff --git a/tox.ini b/tox.ini index b715ea0bff..4ccfde01b5 100644 --- a/tox.ini +++ b/tox.ini @@ -183,6 +183,7 @@ commands = mypy \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ synapse/logging/ \ -- cgit 1.5.1 From 7dcbc33a1be04c46b930699c03c15bc759f4b22c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 3 Mar 2020 07:12:45 -0500 Subject: Validate the alt_aliases property of canonical alias events (#6971) --- changelog.d/6971.feature | 1 + synapse/api/errors.py | 1 + synapse/handlers/directory.py | 14 ++-- synapse/handlers/message.py | 47 ++++++++++- synapse/types.py | 15 ++-- tests/handlers/test_directory.py | 66 +++++++-------- tests/rest/client/v1/test_rooms.py | 160 +++++++++++++++++++++++++++++++++++++ tests/test_types.py | 2 +- 8 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 changelog.d/6971.feature (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/6971.feature b/changelog.d/6971.feature new file mode 100644 index 0000000000..ccf02a61df --- /dev/null +++ b/changelog.d/6971.feature @@ -0,0 +1 @@ +Validate the alt_aliases property of canonical alias events. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0c20601600..616942b057 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -66,6 +66,7 @@ class Codes(object): EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + BAD_ALIAS = "M_BAD_ALIAS" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 0b23ca919a..61eb49059b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import collections import logging import string from typing import List @@ -307,15 +305,17 @@ class DirectoryHandler(BaseHandler): send_update = True content.pop("alias", "") - # Filter alt_aliases for the removed alias. - alt_aliases = content.pop("alt_aliases", None) - # If the aliases are not a list (or not found) do not attempt to modify - # the list. - if isinstance(alt_aliases, collections.Sequence): + # Filter the alt_aliases property for the removed alias. Note that the + # value is not modified if alt_aliases is of an unexpected form. + alt_aliases = content.get("alt_aliases") + if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases: send_update = True alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: content["alt_aliases"] = alt_aliases + else: + del content["alt_aliases"] if send_update: yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a0103addd3..0c84c6cec4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -888,19 +888,60 @@ class EventCreationHandler(object): yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) + # Validate a newly added alias or newly added alt_aliases. + + original_alias = None + original_alt_aliases = set() + + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_event = yield self.store.get_event(original_event_id) + + if original_event: + original_alias = original_event.content.get("alias", None) + original_alt_aliases = original_event.content.get("alt_aliases", []) + + # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - if room_alias_str: + directory_handler = self.hs.get_handlers().directory_handler + if room_alias_str and room_alias_str != original_alias: room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) if mapping["room_id"] != event.room_id: raise SynapseError( 400, "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM + ) + + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + room_alias = RoomAlias.from_string(alias_str) + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" + % (room_alias_str,), + Codes.BAD_ALIAS, + ) + federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: diff --git a/synapse/types.py b/synapse/types.py index f3cd465735..acf60baddc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -23,7 +23,7 @@ import attr from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): @@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( - 400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL) + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) parts = s[1:].split(":", 1) @@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 400, "Expected %s of the form '%slocalname:domain'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) domain = parts[1] @@ -235,11 +238,13 @@ class GroupID(DomainSpecificString): def from_string(cls, s): group_id = super(GroupID, cls).from_string(s) if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty") + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) if contains_invalid_mxid_characters(group_id.localpart): raise SynapseError( - 400, "Group ID can only contain characters a-z, 0-9, or '=_-./'" + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + Codes.INVALID_PARAM, ) return group_id diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 27b916aed4..3397cfa485 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -88,6 +88,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias_not_allowed(self): + """Removing an alias should be denied if a user does not have the proper permissions.""" room_id = "!8765qwer:test" self.get_success( self.store.create_room_alias_association(self.my_room, room_id, ["test"]) @@ -101,6 +102,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias(self): + """Removing an alias should work when a user does has the proper permissions.""" room_id = "!8765qwer:test" user_id = "@user:test" self.get_success( @@ -159,30 +161,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) self.test_alias = "#test:test" - self.room_alias = RoomAlias.from_string(self.test_alias) + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( - self.room_alias, self.room_id, ["test"], self.admin_user + room_alias, self.room_id, ["test"], self.admin_user ) ) + return room_alias - def test_remove_alias(self): - """Removing an alias that is the canonical alias should remove it there too.""" - # Set this new alias as the canonical alias for this room + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, - "m.room.canonical_alias", - {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, - tok=self.admin_user_tok, + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, ) - data = self.get_success( + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( self.state_handler.get_current_state( self.room_id, EventTypes.CanonicalAlias, "" ) ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) @@ -193,11 +207,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) @@ -205,29 +215,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" - other_room_alias = RoomAlias.from_string(other_test_alias) - self.get_success( - self.store.create_room_alias_association( - other_room_alias, self.room_id, ["test"], self.admin_user - ) - ) + other_room_alias = self._add_alias(other_test_alias) # Set the alias as the canonical alias for this room. - self.helper.send_state( - self.room_id, - "m.room.canonical_alias", + self._set_canonical_alias( { "alias": self.test_alias, "alt_aliases": [self.test_alias, other_test_alias], - }, - tok=self.admin_user_tok, + } ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual( data["content"]["alt_aliases"], [self.test_alias, other_test_alias] @@ -240,11 +238,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2f3df5f88f..7dd86d0c27 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/test_types.py b/tests/test_types.py index 8d97c751ea..480bea1bdc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase): self.fail("Parsing '%s' should raise exception" % id_string) except SynapseError as exc: self.assertEqual(400, exc.code) - self.assertEqual("M_UNKNOWN", exc.errcode) + self.assertEqual("M_INVALID_PARAM", exc.errcode) class MapUsernameTestCase(unittest.TestCase): -- cgit 1.5.1 From 69ce55c51082d03e549863f2149b4cf10cb1de19 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 11 Mar 2020 15:21:25 +0000 Subject: Don't filter out dummy events when we're checking the visibility of state --- synapse/handlers/message.py | 2 +- synapse/visibility.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0c84c6cec4..b743fc2dcc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -160,7 +160,7 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.storage, user_id, last_events, apply_retention_policies=False + self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] diff --git a/synapse/visibility.py b/synapse/visibility.py index a48a4f3dfe..1d538b206d 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -49,7 +49,7 @@ def filter_events_for_client( events, is_peeking=False, always_include_ids=frozenset(), - apply_retention_policies=True, + filter_send_to_client=True, ): """ Check which events a user is allowed to see. If the user can see the event but its @@ -65,10 +65,9 @@ def filter_events_for_client( events always_include_ids (set(event_id)): set of event ids to specifically include (unless sender is ignored) - apply_retention_policies (bool): Whether to filter out events that's older than - allowed by the room's retention policy. Useful when this function is called - to e.g. check whether a user should be allowed to see the state at a given - event rather than to know if it should send an event to a user's client(s). + filter_send_to_client (bool): Whether we're checking an event that's going to be + sent to a client. This might not always be the case since this function can + also be called to check whether a user can see the state at a given point. Returns: Deferred[list[synapse.events.EventBase]] @@ -96,7 +95,7 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) - if apply_retention_policies: + if not filter_send_to_client: room_ids = {e.room_id for e in events} retention_policies = {} @@ -119,7 +118,7 @@ def filter_events_for_client( the original event if they can see it as normal. """ - if event.type == "org.matrix.dummy_event": + if event.type == "org.matrix.dummy_event" and filter_send_to_client: return None if not event.is_state() and event.sender in ignore_list: @@ -134,7 +133,7 @@ def filter_events_for_client( # Don't try to apply the room's retention policy if the event is a state event, as # MSC1763 states that retention is only considered for non-state events. - if apply_retention_policies and not event.is_state(): + if filter_send_to_client and not event.is_state(): retention_policy = retention_policies[event.room_id] max_lifetime = retention_policy.get("max_lifetime") -- cgit 1.5.1 From 190ab593b7a2c0d79569758c0faa4d2442bc2c5f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Mar 2020 15:21:54 -0400 Subject: Use the proper error code when a canonical alias that does not exist is used. (#7109) --- changelog.d/7109.bugfix | 1 + synapse/handlers/message.py | 57 ++++++++++++++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 19 deletions(-) create mode 100644 changelog.d/7109.bugfix (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7109.bugfix b/changelog.d/7109.bugfix new file mode 100644 index 0000000000..268de9978e --- /dev/null +++ b/changelog.d/7109.bugfix @@ -0,0 +1 @@ +Return the proper error (M_BAD_ALIAS) when a non-existant canonical alias is provided. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b743fc2dcc..522271eed1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -851,6 +851,38 @@ class EventCreationHandler(object): self.store.remove_push_actions_from_staging, event.event_id ) + @defer.inlineCallbacks + def _validate_canonical_alias( + self, directory_handler, room_alias_str, expected_room_id + ): + """ + Ensure that the given room alias points to the expected room ID. + + Args: + directory_handler: The directory handler object. + room_alias_str: The room alias to check. + expected_room_id: The room ID that the alias should point to. + """ + room_alias = RoomAlias.from_string(room_alias_str) + try: + mapping = yield directory_handler.get_association(room_alias) + except SynapseError as e: + # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. + if e.errcode == Codes.NOT_FOUND: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + raise + + if mapping["room_id"] != expected_room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, + ) + @defer.inlineCallbacks def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] @@ -905,15 +937,9 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - room_alias = RoomAlias.from_string(room_alias_str) - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" % (room_alias_str,), - Codes.BAD_ALIAS, - ) + yield self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) # Check that alt_aliases is the proper form. alt_aliases = event.content.get("alt_aliases", []) @@ -931,16 +957,9 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - room_alias = RoomAlias.from_string(alias_str) - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" - % (room_alias_str,), - Codes.BAD_ALIAS, - ) + yield self._validate_canonical_alias( + directory_handler, alias_str, event.room_id + ) federation_handler = self.hs.get_handlers().federation_handler -- cgit 1.5.1 From 6b22921b195c24762cd7c02a8b8fad75791fce70 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 1 May 2020 15:15:36 +0100 Subject: async/await is_server_admin (#7363) --- changelog.d/7363.misc | 1 + synapse/api/auth.py | 9 +- synapse/federation/federation_client.py | 5 +- synapse/groups/groups_server.py | 64 +++++------ synapse/handlers/_base.py | 16 ++- synapse/handlers/directory.py | 51 ++++----- synapse/handlers/federation.py | 21 ++-- synapse/handlers/groups_local.py | 24 ++-- synapse/handlers/message.py | 83 +++++++------- synapse/handlers/profile.py | 39 ++++--- synapse/handlers/register.py | 49 ++++---- synapse/handlers/room.py | 121 ++++++++++---------- synapse/handlers/room_member.py | 127 ++++++++++----------- synapse/server_notices/consent_server_notices.py | 11 +- .../resource_limits_server_notices.py | 35 +++--- synapse/server_notices/server_notices_manager.py | 32 +++--- synapse/server_notices/server_notices_sender.py | 12 +- synapse/storage/data_stores/main/registration.py | 5 +- tests/handlers/test_profile.py | 60 ++++++---- tests/handlers/test_register.py | 29 +++-- .../test_resource_limits_server_notices.py | 48 ++++---- tests/test_federation.py | 6 +- 22 files changed, 410 insertions(+), 438 deletions(-) create mode 100644 changelog.d/7363.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7363.misc b/changelog.d/7363.misc new file mode 100644 index 0000000000..1e3cddde79 --- /dev/null +++ b/changelog.d/7363.misc @@ -0,0 +1 @@ +Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index c1ade1333b..c5d1eb952b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -537,8 +537,7 @@ class Auth(object): return defer.succeed(auth_ids) - @defer.inlineCallbacks - def check_can_change_room_list(self, room_id: str, user: UserID): + async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the published room list. @@ -547,17 +546,17 @@ class Auth(object): user """ - is_admin = yield self.is_server_admin(user) + is_admin = await self.is_server_admin(user) if is_admin: return True user_id = user.to_string() - yield self.check_user_in_room(room_id, user_id) + await self.check_user_in_room(room_id, user_id) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = yield self.state.get_current_state( + power_level_event = await self.state.get_current_state( room_id, EventTypes.PowerLevels, "" ) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 58b13da616..687cd841ac 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -976,14 +976,13 @@ class FederationClient(FederationBase): return signed_events - @defer.inlineCallbacks - def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite(self, destinations, room_id, event_dict): for destination in destinations: if destination == self.server_name: continue try: - yield self.transport_layer.exchange_third_party_invite( + await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) return None diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4f0dc0a209..4acb4fa489 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise NotImplementedError() - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from the group; either a user is leaving or an admin kicked them. """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) is_kick = False if requester_user_id != user_id: - is_admin = yield self.store.is_user_admin_in_group( + is_admin = await self.store.is_user_admin_in_group( group_id, requester_user_id ) if not is_admin: @@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler): is_kick = True - yield self.store.remove_user_from_group(group_id, user_id) + await self.store.remove_user_from_group(group_id, user_id) if is_kick: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) if not self.hs.is_mine_id(user_id): - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # Delete group if the last user has left - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) if not users: - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) return {} - @defer.inlineCallbacks - def create_group(self, group_id, requester_user_id, content): - group = yield self.check_group_is_ours(group_id, requester_user_id) + async def create_group(self, group_id, requester_user_id, content): + group = await self.check_group_is_ours(group_id, requester_user_id) logger.info("Attempting to create group with ID: %r", group_id) @@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if group: raise SynapseError(400, "Group already exists") - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) if not is_admin: @@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): long_description = profile.get("long_description") user_profile = content.get("user_profile", {}) - yield self.store.create_group( + await self.store.create_group( group_id, requester_user_id, name=name, @@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if not self.hs.is_mine_id(requester_user_id): remote_attestation = content["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, user_id=requester_user_id, group_id=group_id ) @@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): local_attestation = None remote_attestation = None - yield self.store.add_user_to_group( + await self.store.add_user_to_group( group_id, requester_user_id, is_admin=True, @@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): ) if not self.hs.is_mine_id(requester_user_id): - yield self.store.add_remote_profile_cache( + await self.store.add_remote_profile_cache( requester_user_id, displayname=user_profile.get("displayname"), avatar_url=user_profile.get("avatar_url"), @@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"group_id": group_id} - @defer.inlineCallbacks - def delete_group(self, group_id, requester_user_id): + async def delete_group(self, group_id, requester_user_id): """Deletes a group, kicking out all current members. Only group admins or server admins can call this request @@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): Deferred """ - yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) + await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) # Only server admins or group admins can delete groups. - is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) + is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) if not is_admin: - is_admin = yield self.auth.is_server_admin( + is_admin = await self.auth.is_server_admin( UserID.from_string(requester_user_id) ) @@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler): raise SynapseError(403, "User is not an admin") # Before deleting the group lets kick everyone out of it - users = yield self.store.get_users_in_group(group_id, include_private=True) + users = await self.store.get_users_in_group(group_id, include_private=True) - @defer.inlineCallbacks - def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() - yield groups_local.user_removed_from_group(group_id, user_id, {}) + await groups_local.user_removed_from_group(group_id, user_id, {}) else: - yield self.transport_client.remove_user_from_group_notification( + await self.transport_client.remove_user_from_group_notification( get_domain_from_id(user_id), group_id, user_id, {} ) - yield self.store.maybe_delete_remote_profile_cache(user_id) + await self.store.maybe_delete_remote_profile_cache(user_id) # We kick users out in the order of: # 1. Non-admins @@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler): else: non_admins.append(u["user_id"]) - yield concurrently_execute(_kick_user_from_group, non_admins, 10) - yield concurrently_execute(_kick_user_from_group, admins, 10) - yield _kick_user_from_group(requester_user_id) + await concurrently_execute(_kick_user_from_group, non_admins, 10) + await concurrently_execute(_kick_user_from_group, admins, 10) + await _kick_user_from_group(requester_user_id) - yield self.store.delete_group(group_id) + await self.store.delete_group(group_id) def _parse_join_policy_from_contents(content): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 51413d910e..3b781d9836 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -126,30 +126,28 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)) ) - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, context=None): + async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids() - current_state = yield self.store.get_events( + current_state_ids = await context.get_current_state_ids() + current_state = await self.store.get_events( list(current_state_ids.values()) ) else: - current_state = yield self.state_handler.get_current_state( + current_state = await self.state_handler.get_current_state( event.room_id ) current_state = list(current_state.values()) logger.info("maybe_kick_guest_users %r", current_state) - yield self.kick_guest_users(current_state) + await self.kick_guest_users(current_state) - @defer.inlineCallbacks - def kick_guest_users(self, current_state): + async def kick_guest_users(self, current_state): for member_event in current_state: try: if member_event.type != EventTypes.Member: @@ -180,7 +178,7 @@ class BaseHandler(object): # homeserver. requester = synapse.types.create_requester(target_user, is_guest=True) handler = self.hs.get_room_member_handler() - yield handler.update_membership( + await handler.update_membership( requester, target_user, member_event.room_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 53e5f585d9..f2f16b1e43 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler): room_alias, room_id, servers, creator=creator ) - @defer.inlineCallbacks - def create_association( + async def create_association( self, requester: Requester, room_alias: RoomAlias, @@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler): else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) if (self.require_membership and check_membership) and not is_admin: - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) if room_id not in rooms_for_user: raise AuthError( 403, "You must be in the room to create an alias for it" @@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = yield self.can_modify_alias(room_alias, user_id=user_id) + can_create = await self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - yield self._create_association(room_alias, room_id, servers, creator=user_id) + await self._create_association(room_alias, room_id, servers, creator=user_id) - @defer.inlineCallbacks - def delete_association(self, requester: Requester, room_alias: RoomAlias): + async def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler): user_id = requester.user.to_string() try: - can_delete = yield self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, user_id) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) + can_delete = await self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - room_id = yield self._delete_association(room_alias) + room_id = await self._delete_association(room_alias) try: - yield self._update_canonical_alias(requester, user_id, room_id, room_alias) + await self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - @defer.inlineCallbacks - def _update_canonical_alias( + async def _update_canonical_alias( self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. """ - alias_event = yield self.state.get_current_state( + alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" ) @@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler): del content["alt_aliases"] if send_update: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler): # either no interested services, or no service with an exclusive lock return defer.succeed(True) - @defer.inlineCallbacks - def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): """Determine whether a user can delete an alias. One of the following must be true: @@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler): for the current room. """ - creator = yield self.store.get_room_alias_creator(alias.to_string()) + creator = await self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True # Resolve the alias to the corresponding room. - room_mapping = yield self.get_association(alias) + room_mapping = await self.get_association(alias) room_id = room_mapping["room_id"] if not room_id: return False - res = yield self.auth.check_can_change_room_list( + res = await self.auth.check_can_change_room_list( room_id, UserID.from_string(user_id) ) return res - @defer.inlineCallbacks - def edit_published_room_list( + async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str ): """Edit the entry of the room in the published room list. @@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler): 403, "This user is not permitted to publish rooms to the room list" ) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Unknown room") - can_change_room_list = yield self.auth.check_can_change_room_list( + can_change_room_list = await self.auth.check_can_change_room_list( room_id, requester.user ) if not can_change_room_list: @@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler): making_public = visibility == "public" if making_public: - room_aliases = yield self.store.get_aliases_for_room(room_id) - canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) + room_aliases = await self.store.get_aliases_for_room(room_id) + canonical_alias = await self.store.get_canonical_alias_for_room(room_id) if canonical_alias: room_aliases.append(canonical_alias) @@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") - yield self.store.set_room_is_public(room_id, making_public) + await self.store.set_room_is_public(room_id, making_public) @defer.inlineCallbacks def edit_published_appservice_room_list( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 41b96c0a73..4e5c645525 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler): "missing": [e.event_id for e in missing_locals], } - @defer.inlineCallbacks @log_function - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): third_party_invite = {"signed": signed} @@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler): "state_key": target_user_id, } - if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): - room_version = yield self.store.get_room_version_id(room_id) + if await self.auth.check_host_in_room(room_id, self.hs.hostname): + room_version = await self.store.get_room_version_id(room_id) builder = self.event_builder_factory.new(room_version, event_dict) EventValidator().validate_builder(builder) - event, context = yield self.event_creation_handler.create_new_client_event( + event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - event, context = yield self.add_display_name_to_third_party_invite( + event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = self.hs.hostname try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, context) + await self._check_signature(event, context) # We retrieve the room member handler here as to not cause a cyclic dependency member_handler = self.hs.get_room_member_handler() - yield member_handler.send_membership_event(None, event, context) + await member_handler.send_membership_event(None, event, context) else: destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} - yield self.federation_client.forward_third_party_invite( + await self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ad22415782..ca5c83811a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - @defer.inlineCallbacks - def create_group(self, group_id, user_id, content): + async def create_group(self, group_id, user_id, content): """Create a group """ logger.info("Asking to create group with ID: %r", group_id) if self.is_mine_id(group_id): - res = yield self.groups_server_handler.create_group( + res = await self.groups_server_handler.create_group( group_id, user_id, content ) local_attestation = None @@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = self.attestations.create_attestation(group_id, user_id) content["attestation"] = local_attestation - content["user_profile"] = yield self.profile_handler.get_profile(user_id) + content["user_profile"] = await self.profile_handler.get_profile(user_id) try: - res = yield self.transport_client.create_group( + res = await self.transport_client.create_group( get_domain_from_id(group_id), group_id, user_id, content ) except HttpResponseException as e: @@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): raise SynapseError(502, "Failed to contact group server") remote_attestation = res["attestation"] - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( remote_attestation, group_id=group_id, user_id=user_id, @@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): ) is_publicised = content.get("publicise", False) - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="join", @@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} - @defer.inlineCallbacks - def remove_user_from_group(self, group_id, user_id, requester_user_id, content): + async def remove_user_from_group( + self, group_id, user_id, requester_user_id, content + ): """Remove a user from a group """ if user_id == requester_user_id: - token = yield self.store.register_user_group_membership( + token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) self.notifier.on_new_event("groups_key", token, users=[user_id]) @@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): # retry if the group server is currently down. if self.is_mine_id(group_id): - res = yield self.groups_server_handler.remove_user_from_group( + res = await self.groups_server_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) else: content["requester_user_id"] = requester_user_id try: - res = yield self.transport_client.remove_user_from_group( + res = await self.transport_client.remove_user_from_group( get_domain_from_id(group_id), group_id, requester_user_id, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 522271eed1..a324f09340 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -626,8 +626,7 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - @defer.inlineCallbacks - def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event(self, requester, event, context, ratelimit=True): """ Persists and notifies local clients and federation of an event. @@ -647,7 +646,7 @@ class EventCreationHandler(object): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = yield self.deduplicate_state_event(event, context) + prev_state = await self.deduplicate_state_event(event, context) if prev_state is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", @@ -656,7 +655,7 @@ class EventCreationHandler(object): ) return prev_state - yield self.handle_new_client_event( + await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -683,8 +682,7 @@ class EventCreationHandler(object): return prev_event return - @defer.inlineCallbacks - def create_and_send_nonmember_event( + async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None ): """ @@ -698,8 +696,8 @@ class EventCreationHandler(object): # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - with (yield self.limiter.queue(event_dict["room_id"])): - event, context = yield self.create_event( + with (await self.limiter.queue(event_dict["room_id"])): + event, context = await self.create_event( requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) @@ -709,7 +707,7 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) return event @@ -770,8 +768,7 @@ class EventCreationHandler(object): return (event, context) @measure_func("handle_new_client_event") - @defer.inlineCallbacks - def handle_new_client_event( + async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Processes a new event. This includes checking auth, persisting it, @@ -794,9 +791,9 @@ class EventCreationHandler(object): ): room_version = event.content.get("room_version", RoomVersions.V1.identifier) else: - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = yield self.third_party_event_rules.check_event_allowed( + event_allowed = await self.third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -805,7 +802,7 @@ class EventCreationHandler(object): ) try: - yield self.auth.check_from_context(room_version, event, context) + await self.auth.check_from_context(room_version, event, context) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) raise err @@ -818,7 +815,7 @@ class EventCreationHandler(object): logger.exception("Failed to encode content: %r", event.content) raise - yield self.action_generator.handle_push_actions_for_event(event, context) + await self.action_generator.handle_push_actions_for_event(event, context) # reraise does not allow inlineCallbacks to preserve the stacktrace, so we # hack around with a try/finally instead. @@ -826,7 +823,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - yield self.send_event_to_master( + await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -838,7 +835,7 @@ class EventCreationHandler(object): success = True return - yield self.persist_and_notify_client_event( + await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) @@ -883,8 +880,7 @@ class EventCreationHandler(object): Codes.BAD_ALIAS, ) - @defer.inlineCallbacks - def persist_and_notify_client_event( + async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] ): """Called when we have fully built the event, have already @@ -901,7 +897,7 @@ class EventCreationHandler(object): # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -913,11 +909,11 @@ class EventCreationHandler(object): original_event and event.sender != original_event.sender ) - yield self.base_handler.ratelimit( + await self.base_handler.ratelimit( requester, is_admin_redaction=is_admin_redaction ) - yield self.base_handler.maybe_kick_guest_users(event, context) + await self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Validate a newly added alias or newly added alt_aliases. @@ -927,7 +923,7 @@ class EventCreationHandler(object): original_event_id = event.unsigned.get("replaces_state") if original_event_id: - original_event = yield self.store.get_event(original_event_id) + original_event = await self.store.get_event(original_event_id) if original_event: original_alias = original_event.content.get("alias", None) @@ -937,7 +933,7 @@ class EventCreationHandler(object): room_alias_str = event.content.get("alias", None) directory_handler = self.hs.get_handlers().directory_handler if room_alias_str and room_alias_str != original_alias: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, room_alias_str, event.room_id ) @@ -957,7 +953,7 @@ class EventCreationHandler(object): new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) if new_alt_aliases: for alias_str in new_alt_aliases: - yield self._validate_canonical_alias( + await self._validate_canonical_alias( directory_handler, alias_str, event.room_id ) @@ -969,7 +965,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids() + current_state_ids = await context.get_current_state_ids() state_to_include_ids = [ e_id @@ -978,7 +974,7 @@ class EventCreationHandler(object): or k == (EventTypes.Member, event.sender) ] - state_to_include = yield self.store.get_events(state_to_include_ids) + state_to_include = await self.store.get_events(state_to_include_ids) event.unsigned["invite_room_state"] = [ { @@ -996,8 +992,8 @@ class EventCreationHandler(object): # way? If we have been invited by a remote server, we need # to get them to sign the event. - returned_invite = yield defer.ensureDeferred( - federation_handler.send_invite(invitee.domain, event) + returned_invite = await federation_handler.send_invite( + invitee.domain, event ) event.unsigned.pop("room_state", None) @@ -1005,7 +1001,7 @@ class EventCreationHandler(object): event.signatures.update(returned_invite.signatures) if event.type == EventTypes.Redaction: - original_event = yield self.store.get_event( + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -1021,14 +1017,14 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids() - auth_events_ids = yield self.auth.compute_auth_events( + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} - room_version = yield self.store.get_room_version_id(event.room_id) + room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if event_auth.check_redaction( @@ -1047,11 +1043,11 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( + event_stream_id, max_stream_id = await self.storage.persistence.persist_event( event, context=context ) @@ -1059,7 +1055,7 @@ class EventCreationHandler(object): # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): try: @@ -1083,13 +1079,12 @@ class EventCreationHandler(object): except Exception: logger.exception("Error bumping presence active time") - @defer.inlineCallbacks - def _send_dummy_events_to_fill_extremities(self): + async def _send_dummy_events_to_fill_extremities(self): """Background task to send dummy events into rooms that have a large number of extremities """ self._expire_rooms_to_exclude_from_dummy_event_insertion() - room_ids = yield self.store.get_rooms_with_many_extremities( + room_ids = await self.store.get_rooms_with_many_extremities( min_count=10, limit=5, room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), @@ -1099,9 +1094,9 @@ class EventCreationHandler(object): # For each room we need to find a joined member we can use to send # the dummy event with. - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - members = yield self.state.get_current_users_in_room( + members = await self.state.get_current_users_in_room( room_id, latest_event_ids=latest_event_ids ) dummy_event_sent = False @@ -1110,7 +1105,7 @@ class EventCreationHandler(object): continue requester = create_requester(user_id) try: - event, context = yield self.create_event( + event, context = await self.create_event( requester, { "type": "org.matrix.dummy_event", @@ -1123,7 +1118,7 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False - yield self.send_nonmember_event( + await self.send_nonmember_event( requester, event, context, ratelimit=False ) dummy_event_sent = True diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..302efc1b9a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler): return result["displayname"] - @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname, by_admin=False): + async def set_displayname( + self, target_user, requester, new_displayname, by_admin=False + ): """Set the displayname of a user Args: @@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's displayname") if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( 400, @@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname(target_user.localpart, new_displayname) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def get_avatar_url(self, target_user): @@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler): return result["avatar_url"] - @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): + async def set_avatar_url( + self, target_user, requester, new_avatar_url, by_admin=False + ): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): @@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's avatar_url") if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) + profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN @@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler): if by_admin: requester = create_requester(target_user) - yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(target_user.localpart) - yield self.user_directory_handler.handle_local_profile_change( + profile = await self.store.get_profileinfo(target_user.localpart) + await self.user_directory_handler.handle_local_profile_change( target_user.to_string(), profile ) - yield self._update_join_states(requester, target_user) + await self._update_join_states(requester, target_user) @defer.inlineCallbacks def on_profile_query(self, args): @@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler): return response - @defer.inlineCallbacks - def _update_join_states(self, requester, target_user): + async def _update_join_states(self, requester, target_user): if not self.hs.is_mine(target_user): return - yield self.ratelimit(requester) + await self.ratelimit(requester) - room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) + room_ids = await self.store.get_rooms_for_user(target_user.to_string()) for room_id in room_ids: handler = self.hs.get_room_member_handler() try: # Assume the target_user isn't a guest, # because we don't let guests set profile or avatar data. - yield handler.update_membership( + await handler.update_membership( requester, target_user, room_id, diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 3a65b46ecd..1e6bdac0ad 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler): """Registers a new client on the server. Args: - localpart : The local part of the user ID to register. If None, + localpart: The local part of the user ID to register. If None, one will be generated. - password (unicode) : The password to assign to this user so they can + password (unicode): The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). user_type (str|None): type of user. One of the values from @@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield self._auto_join_rooms(user_id) + yield defer.ensureDeferred(self._auto_join_rooms(user_id)) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler): return user_id - @defer.inlineCallbacks - def _auto_join_rooms(self, user_id): + async def _auto_join_rooms(self, user_id): """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler): # that an auto-generated support or bot user is not a real user and will never be # the user to create the room should_auto_create_rooms = False - is_real_user = yield self.store.is_real_user(user_id) + is_real_user = await self.store.is_real_user(user_id) if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = yield self.store.count_real_users() + count = await self.store.count_real_users() should_auto_create_rooms = count == 1 for r in self.hs.config.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) @@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler): # getting the RoomCreationHandler during init gives a dependency # loop - yield self.hs.get_room_creation_handler().create_room( + await self.hs.get_room_creation_handler().create_room( fake_requester, config={ "preset": "public_chat", @@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) else: - yield self._join_user_to_room(fake_requester, r) + await self._join_user_to_room(fake_requester, r) except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - @defer.inlineCallbacks - def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted Args: user_id (str): The user to join """ - yield self._auto_join_rooms(user_id) + await self._auto_join_rooms(user_id) @defer.inlineCallbacks def appservice_register(self, user_localpart, as_token): @@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id += 1 return str(id) - @defer.inlineCallbacks - def _join_user_to_room(self, requester, room_identifier): + async def _join_user_to_room(self, requester, room_identifier): room_member_handler = self.hs.get_room_member_handler() if RoomID.is_valid(room_identifier): room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( + room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( room_alias ) room_id = room_id.to_string() @@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler): 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield room_member_handler.update_membership( + await room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler): return (device_id, access_token) - @defer.inlineCallbacks - def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions(self, user_id, auth_result, access_token): """A user has completed registration Args: @@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler): device, or None if `inhibit_login` enabled. """ if self.hs.config.worker_app: - yield self._post_registration_client( + await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token ) return @@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(user_id) + await self.store.upsert_monthly_active_user(user_id) - yield self._register_email_threepid(user_id, threepid, access_token) + await self._register_email_threepid(user_id, threepid, access_token) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] - yield self._register_msisdn_threepid(user_id, threepid) + await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - yield self._on_user_consented(user_id, self.hs.config.user_consent_version) + await self._on_user_consented(user_id, self.hs.config.user_consent_version) - @defer.inlineCallbacks - def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id, consent_version): """A user consented to the terms on registration Args: @@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler): consented to. """ logger.info("%s has consented to the privacy policy", user_id) - yield self.store.user_set_consent_version(user_id, consent_version) - yield self.post_consent_actions(user_id) + await self.store.user_set_consent_version(user_id, consent_version) + await self.post_consent_actions(user_id) @defer.inlineCallbacks def _register_email_threepid(self, user_id, threepid, token): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3d10e4b2d9..da12df7f53 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _upgrade_room( + async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): user_id = requester.user.to_string() # start by allocating a new room id - r = yield self.store.get_room(old_room_id) + r = await self.store.get_room(old_room_id) if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = yield self._generate_room_id( + new_room_id = await self._generate_room_id( creator_id=user_id, is_public=r["is_public"], room_version=new_version, ) @@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler): ( tombstone_event, tombstone_context, - ) = yield self.event_creation_handler.create_event( + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Tombstone, @@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler): }, token_id=requester.access_token_id, ) - old_room_version = yield self.store.get_room_version_id(old_room_id) - yield self.auth.check_from_context( + old_room_version = await self.store.get_room_version_id(old_room_id) + await self.auth.check_from_context( old_room_version, tombstone_event, tombstone_context ) - yield self.clone_existing_room( + await self.clone_existing_room( requester, old_room_id=old_room_id, new_room_id=new_room_id, @@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler): ) # now send the tombstone - yield self.event_creation_handler.send_nonmember_event( + await self.event_creation_handler.send_nonmember_event( requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids() + old_room_state = await tombstone_context.get_current_state_ids() # update any aliases - yield self._move_aliases_to_new_room( + await self._move_aliases_to_new_room( requester, old_room_id, new_room_id, old_room_state ) # Copy over user push rules, tags and migrate room directory state - yield self.room_member_handler.transfer_room_state_on_room_upgrade( + await self.room_member_handler.transfer_room_state_on_room_upgrade( old_room_id, new_room_id ) # finally, shut down the PLs in the old room, and update them in the new # room. - yield self._update_upgraded_room_pls( + await self._update_upgraded_room_pls( requester, old_room_id, new_room_id, old_room_state, ) return new_room_id - @defer.inlineCallbacks - def _update_upgraded_room_pls( + async def _update_upgraded_room_pls( self, requester: Requester, old_room_id: str, @@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler): ) return - old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) + old_room_pl_state = await self.store.get_event(old_room_pl_event_id) # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. That's normally @@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler): if updated: try: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.PowerLevels, @@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def clone_existing_room( + async def clone_existing_room( self, requester: Requester, old_room_id: str, @@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler): # Check if old room was non-federatable # Get old room's create event - old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) + old_room_create_event = await self.store.get_create_event_for_room(old_room_id) # Check if the create event specified a non-federatable room if not old_room_create_event.content.get("m.federate", True): @@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler): (EventTypes.PowerLevels, ""), ) - old_room_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types(types_to_copy) ) # map from event_id to BaseEvent - old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) + old_room_state_events = await self.store.get_events(old_room_state_ids.values()) for k, old_event_id in iteritems(old_room_state_ids): old_event = old_room_state_events.get(old_event_id) @@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler): if current_power_level < needed_power_level: power_levels["users"][user_id] = needed_power_level - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, new_room_id, # we expect to override all the presets with initial_state, so this is @@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler): ) # Transfer membership events - old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( + old_room_member_state_ids = await self.store.get_filtered_current_state_ids( old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) ) # map from event_id to BaseEvent - old_room_member_state_events = yield self.store.get_events( + old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) for k, old_event in iteritems(old_room_member_state_events): @@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler): "membership" in old_event.content and old_event.content["membership"] == "ban" ): - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(old_event["state_key"]), new_room_id, @@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler): # XXX invites/joins # XXX 3pid invites - @defer.inlineCallbacks - def _move_aliases_to_new_room( + async def _move_aliases_to_new_room( self, requester: Requester, old_room_id: str, @@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler): ): directory_handler = self.hs.get_handlers().directory_handler - aliases = yield self.store.get_aliases_for_room(old_room_id) + aliases = await self.store.get_aliases_for_room(old_room_id) # check to see if we have a canonical alias. canonical_alias_event = None canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) if canonical_alias_event_id: - canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) + canonical_alias_event = await self.store.get_event(canonical_alias_event_id) # first we try to remove the aliases from the old room (we suppress sending # the room_aliases event until the end). @@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler): for alias_str in aliases: alias = RoomAlias.from_string(alias_str) try: - yield directory_handler.delete_association(requester, alias) + await directory_handler.delete_association(requester, alias) removed_aliases.append(alias_str) except SynapseError as e: logger.warning("Unable to remove alias %s from old room: %s", alias, e) @@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler): # we can now add any aliases we successfully removed to the new room. for alias in removed_aliases: try: - yield directory_handler.create_association( + await directory_handler.create_association( requester, RoomAlias.from_string(alias), new_room_id, @@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler): # alias event for the new room with a copy of the information. try: if canonical_alias_event: - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.CanonicalAlias, @@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler): # we returned the new room to the client at this point. logger.error("Unable to send updated alias events in new room: %s", e) - @defer.inlineCallbacks - def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): + async def create_room( + self, requester, config, ratelimit=True, creator_join_profile=None + ): """ Creates a new room. Args: @@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) if ( self._server_notices_mxid is not None @@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler): # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) # Check whether the third party rules allows/changes the room create # request. - event_allowed = yield self.third_party_event_rules.on_create_room( + event_allowed = await self.third_party_event_rules.on_create_room( requester, config, is_requester_admin=is_requester_admin ) if not event_allowed: @@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - yield self.ratelimit(requester) + await self.ratelimit(requester) room_version_id = config.get( "room_version", self.config.default_room_version.identifier @@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(400, "Invalid characters in room alias") room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) - mapping = yield self.store.get_association_from_room_alias(room_alias) + mapping = await self.store.get_association_from_room_alias(room_alias) if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) @@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) - yield self.event_creation_handler.assert_accepted_privacy_policy(requester) + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") if ( @@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler): visibility = config.get("visibility", None) is_public = visibility == "public" - room_id = yield self._generate_room_id( + room_id = await self._generate_room_id( creator_id=user_id, is_public=is_public, room_version=room_version, ) directory_handler = self.hs.get_handlers().directory_handler if room_alias: - yield directory_handler.create_association( + await directory_handler.create_association( requester=requester, room_id=room_id, room_alias=room_alias, @@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - yield self._send_events_for_new_room( + await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - yield self.hs.get_room_member_handler().do_3pid_invite( + await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler): return result - @defer.inlineCallbacks - def _send_events_for_new_room( + async def _send_events_for_new_room( self, creator, # A Requester object. room_id, @@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler): return e - @defer.inlineCallbacks - def send(etype, content, **kwargs): + async def send(etype, content, **kwargs): event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) @@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler): event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} creation_content.update({"creator": creator_id}) - yield send(etype=EventTypes.Create, content=creation_content) + await send(etype=EventTypes.Create, content=creation_content) logger.debug("Sending %s in new room", EventTypes.Member) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( creator, creator.user, room_id, @@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - yield send(etype=EventTypes.PowerLevels, content=pl_content) + await send(etype=EventTypes.PowerLevels, content=pl_content) else: power_level_content = { "users": {creator_id: 100}, @@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - yield send(etype=EventTypes.PowerLevels, content=power_level_content) + await send(etype=EventTypes.PowerLevels, content=power_level_content) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - yield send( + await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - yield send( + await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - yield send( + await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - yield send( + await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - yield send(etype=etype, state_key=state_key, content=content) + await send(etype=etype, state_key=state_key, content=content) @defer.inlineCallbacks def _generate_room_id( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index c3ee8db4f0..53b49bc15f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -142,8 +142,7 @@ class RoomMemberHandler(object): """ raise NotImplementedError() - @defer.inlineCallbacks - def _local_membership_update( + async def _local_membership_update( self, requester, target, @@ -164,7 +163,7 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - event, context = yield self.event_creation_handler.create_event( + event, context = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -182,18 +181,18 @@ class RoomMemberHandler(object): ) # Check if this event matches the previous membership event for the user. - duplicate = yield self.event_creation_handler.deduplicate_state_event( + duplicate = await self.event_creation_handler.deduplicate_state_event( event, context ) if duplicate is not None: # Discard the new event since this membership change is a no-op. return duplicate - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -203,15 +202,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target, room_id) + await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target, room_id) + await self._user_left_room(target, room_id) return event @@ -253,8 +252,7 @@ class RoomMemberHandler(object): for tag, tag_content in room_tags.items(): yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) - @defer.inlineCallbacks - def update_membership( + async def update_membership( self, requester, target, @@ -269,8 +267,8 @@ class RoomMemberHandler(object): ): key = (room_id,) - with (yield self.member_linearizer.queue(key)): - result = yield self._update_membership( + with (await self.member_linearizer.queue(key)): + result = await self._update_membership( requester, target, room_id, @@ -285,8 +283,7 @@ class RoomMemberHandler(object): return result - @defer.inlineCallbacks - def _update_membership( + async def _update_membership( self, requester, target, @@ -321,7 +318,7 @@ class RoomMemberHandler(object): # if this is a join with a 3pid signature, we may need to turn a 3pid # invite into a normal invite before we can handle the join. if third_party_signed is not None: - yield self.federation_handler.exchange_third_party_invite( + await self.federation_handler.exchange_third_party_invite( third_party_signed["sender"], target.to_string(), room_id, @@ -332,7 +329,7 @@ class RoomMemberHandler(object): remote_room_hosts = [] if effective_membership_state not in ("leave", "ban"): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") @@ -351,7 +348,7 @@ class RoomMemberHandler(object): is_requester_admin = True else: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: if self.config.block_non_admin_invites: @@ -370,9 +367,9 @@ class RoomMemberHandler(object): if block_invite: raise SynapseError(403, "Invites have been disabled on this server") - latest_event_ids = yield self.store.get_prev_events_for_room(room_id) + latest_event_ids = await self.store.get_prev_events_for_room(room_id) - current_state_ids = yield self.state_handler.get_current_state_ids( + current_state_ids = await self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids ) @@ -381,7 +378,7 @@ class RoomMemberHandler(object): # transitions and generic otherwise old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) if old_state_id: - old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None if action == "unban" and old_membership != "ban": raise SynapseError( @@ -413,7 +410,7 @@ class RoomMemberHandler(object): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = yield self._is_server_notice_room(room_id) + is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( http_client.FORBIDDEN, @@ -424,18 +421,18 @@ class RoomMemberHandler(object): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = yield self._is_host_in_room(current_state_ids) + is_host_in_room = await self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(current_state_ids) + guest_can_join = await self._can_guest_join(current_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if not is_host_in_room: - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): remote_room_hosts.append(inviter.domain) @@ -443,13 +440,13 @@ class RoomMemberHandler(object): profile = self.profile_handler if not content_specified: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + content["displayname"] = await profile.get_displayname(target) + content["avatar_url"] = await profile.get_avatar_url(target) if requester.is_guest: content["kind"] = "guest" - remote_join_response = yield self._remote_join( + remote_join_response = await self._remote_join( requester, remote_room_hosts, room_id, target, content ) @@ -458,7 +455,7 @@ class RoomMemberHandler(object): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: # perhaps we've been invited - inviter = yield self._get_inviter(target.to_string(), room_id) + inviter = await self._get_inviter(target.to_string(), room_id) if not inviter: raise SynapseError(404, "Not a known room") @@ -472,12 +469,12 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = yield self._remote_reject_invite( + res = await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) return res - res = yield self._local_membership_update( + res = await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -572,8 +569,7 @@ class RoomMemberHandler(object): ) continue - @defer.inlineCallbacks - def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event(self, requester, event, context, ratelimit=True): """ Change the membership status of a user in a room. @@ -599,27 +595,27 @@ class RoomMemberHandler(object): else: requester = types.create_requester(target_user) - prev_event = yield self.event_creation_handler.deduplicate_state_event( + prev_event = await self.event_creation_handler.deduplicate_state_event( event, context ) if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: - guest_can_join = yield self._can_guest_join(prev_state_ids) + guest_can_join = await self._can_guest_join(prev_state_ids) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") if event.membership not in (Membership.LEAVE, Membership.BAN): - is_blocked = yield self.store.is_room_blocked(room_id) + is_blocked = await self.store.is_room_blocked(room_id) if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - yield self.event_creation_handler.handle_new_client_event( + await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) @@ -633,15 +629,15 @@ class RoomMemberHandler(object): # info. newly_joined = True if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: - yield self._user_joined_room(target_user, room_id) + await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = yield self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: - yield self._user_left_room(target_user, room_id) + await self._user_left_room(target_user, room_id) @defer.inlineCallbacks def _can_guest_join(self, current_state_ids): @@ -699,8 +695,7 @@ class RoomMemberHandler(object): if invite: return UserID.from_string(invite.sender) - @defer.inlineCallbacks - def do_3pid_invite( + async def do_3pid_invite( self, room_id, inviter, @@ -712,7 +707,7 @@ class RoomMemberHandler(object): id_access_token=None, ): if self.config.block_non_admin_invites: - is_requester_admin = yield self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -720,9 +715,9 @@ class RoomMemberHandler(object): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - yield self.base_handler.ratelimit(requester) + await self.base_handler.ratelimit(requester) - can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id ) if not can_invite: @@ -737,16 +732,16 @@ class RoomMemberHandler(object): 403, "Looking up third-party identifiers is denied from this server" ) - invitee = yield self.identity_handler.lookup_3pid( + invitee = await self.identity_handler.lookup_3pid( id_server, medium, address, id_access_token ) if invitee: - yield self.update_membership( + await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - yield self._make_and_store_3pid_invite( + await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -757,8 +752,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - @defer.inlineCallbacks - def _make_and_store_3pid_invite( + async def _make_and_store_3pid_invite( self, requester, id_server, @@ -769,7 +763,7 @@ class RoomMemberHandler(object): txn_id, id_access_token=None, ): - room_state = yield self.state_handler.get_current_state(room_id) + room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" inviter_avatar_url = "" @@ -807,7 +801,7 @@ class RoomMemberHandler(object): public_keys, fallback_public_key, display_name, - ) = yield self.identity_handler.ask_id_server_for_third_party_invite( + ) = await self.identity_handler.ask_id_server_for_third_party_invite( requester=requester, id_server=id_server, medium=medium, @@ -823,7 +817,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - yield self.event_creation_handler.create_and_send_nonmember_event( + await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity - @defer.inlineCallbacks - def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: # Fetch the room complexity - too_complex = yield self._is_remote_room_too_complex( + too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts ) if too_complex is True: @@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - yield defer.ensureDeferred( - self.federation_handler.do_invite_join( - remote_room_hosts, room_id, user.to_string(), content - ) + await self.federation_handler.do_invite_join( + remote_room_hosts, room_id, user.to_string(), content ) - yield self._user_joined_room(user, room_id) + await self._user_joined_room(user, room_id) # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. @@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return # Check again, but with the local state events - too_complex = yield self._is_local_room_too_complex(room_id) + too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. @@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # The room is too large. Leave. requester = types.create_requester(user, None, False, None) - yield self.update_membership( + await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) raise SynapseError( @@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler): def _user_joined_room(self, target, room_id): """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + return defer.succeed(user_joined_room(self.distributor, target, room_id)) def _user_left_room(self, target, room_id): """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + return defer.succeed(user_left_room(self.distributor, target, room_id)) @defer.inlineCallbacks def forget(self, user, room_id): diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 5736c56032..3bf330da49 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -16,8 +16,6 @@ import logging from six import iteritems, string_types -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.api.urls import ConsentURIBuilder from synapse.config import ConfigError @@ -59,8 +57,7 @@ class ConsentServerNotices(object): self._consent_uri_builder = ConsentURIBuilder(hs.config) - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, and does so if so Args: @@ -78,7 +75,7 @@ class ConsentServerNotices(object): return self._users_in_progress.add(user_id) try: - u = yield self._store.get_user_by_id(user_id) + u = await self._store.get_user_by_id(user_id) if u["is_guest"] and not self._send_to_guests: # don't send to guests @@ -100,8 +97,8 @@ class ConsentServerNotices(object): content = copy_with_str_subst( self._server_notice_content, {"consent_uri": consent_uri} ) - yield self._server_notices_manager.send_notice(user_id, content) - yield self._store.user_set_consent_server_notice_sent( + await self._server_notices_manager.send_notice(user_id, content) + await self._store.user_set_consent_server_notice_sent( user_id, self._current_consent_version ) except SynapseError as e: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index ce4a828894..971771b8b2 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -50,8 +50,7 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() - @defer.inlineCallbacks - def maybe_send_server_notice_to_user(self, user_id): + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. 1. The server has reached its limit does not reflect this @@ -74,13 +73,13 @@ class ResourceLimitsServerNotices(object): # Don't try and send server notices unless they've been enabled return - timestamp = yield self._store.user_last_seen_monthly_active(user_id) + timestamp = await self._store.user_last_seen_monthly_active(user_id) if timestamp is None: # This user will be blocked from receiving the notice anyway. # In practice, not sure we can ever get here return - room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( + room_id = await self._server_notices_manager.get_or_create_notice_room_for_user( user_id ) @@ -88,10 +87,10 @@ class ResourceLimitsServerNotices(object): logger.warning("Failed to get server notices room") return - yield self._check_and_set_tags(user_id, room_id) + await self._check_and_set_tags(user_id, room_id) # Determine current state of room - currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) + currently_blocked, ref_events = await self._is_room_currently_blocked(room_id) limit_msg = None limit_type = None @@ -99,7 +98,7 @@ class ResourceLimitsServerNotices(object): # Normally should always pass in user_id to check_auth_blocking # if you have it, but in this case are checking what would happen # to other users if they were to arrive. - yield self._auth.check_auth_blocking() + await self._auth.check_auth_blocking() except ResourceLimitError as e: limit_msg = e.msg limit_type = e.limit_type @@ -112,22 +111,21 @@ class ResourceLimitsServerNotices(object): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return if currently_blocked: - self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) return if currently_blocked and not limit_msg: # Room is notifying of a block, when it ought not to be. - yield self._remove_limit_block_notification(user_id, ref_events) + await self._remove_limit_block_notification(user_id, ref_events) elif not currently_blocked and limit_msg: # Room is not notifying of a block, when it ought to be. - yield self._apply_limit_block_notification( + await self._apply_limit_block_notification( user_id, limit_msg, limit_type ) except SynapseError as e: logger.error("Error sending resource limits server notice: %s", e) - @defer.inlineCallbacks - def _remove_limit_block_notification(self, user_id, ref_events): + async def _remove_limit_block_notification(self, user_id, ref_events): """Utility method to remove limit block notifications from the server notices room. @@ -137,12 +135,13 @@ class ResourceLimitsServerNotices(object): limit blocking and need to be preserved. """ content = {"pinned": ref_events} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) - @defer.inlineCallbacks - def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): + async def _apply_limit_block_notification( + self, user_id, event_body, event_limit_type + ): """Utility method to apply limit block notifications in the server notices room. @@ -159,12 +158,12 @@ class ResourceLimitsServerNotices(object): "admin_contact": self._config.admin_contact, "limit_type": event_limit_type, } - event = yield self._server_notices_manager.send_notice( + event = await self._server_notices_manager.send_notice( user_id, content, EventTypes.Message ) content = {"pinned": [event.event_id]} - yield self._server_notices_manager.send_notice( + await self._server_notices_manager.send_notice( user_id, content, EventTypes.Pinned, "" ) @@ -198,7 +197,7 @@ class ResourceLimitsServerNotices(object): room_id(str): The room id of the server notices room Returns: - + Deferred[Tuple[bool, List]]: bool: Is the room currently blocked list: The list of pinned events that are unrelated to limit blocking This list can be used as a convenience in the case where the block diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index bf0943f265..999c621b92 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -14,11 +14,9 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.types import UserID, create_requester -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -51,8 +49,7 @@ class ServerNoticesManager(object): """ return self._config.server_notices_mxid is not None - @defer.inlineCallbacks - def send_notice( + async def send_notice( self, user_id, event_content, type=EventTypes.Message, state_key=None ): """Send a notice to the given user @@ -68,8 +65,8 @@ class ServerNoticesManager(object): Returns: Deferred[FrozenEvent] """ - room_id = yield self.get_or_create_notice_room_for_user(user_id) - yield self.maybe_invite_user_to_room(user_id, room_id) + room_id = await self.get_or_create_notice_room_for_user(user_id) + await self.maybe_invite_user_to_room(user_id, room_id) system_mxid = self._config.server_notices_mxid requester = create_requester(system_mxid) @@ -86,13 +83,13 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = yield self._event_creation_handler.create_and_send_nonmember_event( + res = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) return res - @cachedInlineCallbacks() - def get_or_create_notice_room_for_user(self, user_id): + @cached() + async def get_or_create_notice_room_for_user(self, user_id): """Get the room for notices for a given user If we have not yet created a notice room for this user, create it, but don't @@ -109,7 +106,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in rooms: @@ -118,7 +115,7 @@ class ServerNoticesManager(object): # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = yield self._store.get_users_in_room(room.room_id) + user_ids = await self._store.get_users_in_room(room.room_id) if self.server_notices_mxid in user_ids: # we found a room which our user shares with the system notice # user @@ -146,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, @@ -158,7 +155,7 @@ class ServerNoticesManager(object): ) room_id = info["room_id"] - max_id = yield self._store.add_tag_to_room( + max_id = await self._store.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) @@ -166,8 +163,7 @@ class ServerNoticesManager(object): logger.info("Created server notices room %s for %s", room_id, user_id) return room_id - @defer.inlineCallbacks - def maybe_invite_user_to_room(self, user_id: str, room_id: str): + async def maybe_invite_user_to_room(self, user_id: str, room_id: str): """Invite the given user to the given server room, unless the user has already joined or been invited to it. @@ -179,14 +175,14 @@ class ServerNoticesManager(object): # Check whether the user has already joined or been invited to this room. If # that's the case, there is no need to re-invite them. - joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( + joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) for room in joined_rooms: if room.room_id == room_id: return - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( requester=requester, target=UserID.from_string(user_id), room_id=room_id, diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py index 652bab58e3..be74e86641 100644 --- a/synapse/server_notices/server_notices_sender.py +++ b/synapse/server_notices/server_notices_sender.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, @@ -36,18 +34,16 @@ class ServerNoticesSender(object): ResourceLimitsServerNotices(hs), ) - @defer.inlineCallbacks - def on_user_syncing(self, user_id): + async def on_user_syncing(self, user_id): """Called when the user performs a sync operation. Args: user_id (str): mxid of user who synced """ for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) - @defer.inlineCallbacks - def on_user_ip(self, user_id): + async def on_user_ip(self, user_id): """Called on the master when a worker process saw a client request. Args: @@ -57,4 +53,4 @@ class ServerNoticesSender(object): # we check for notices to send to the user in on_user_ip as well as # in on_user_syncing for sn in self._server_notices: - yield sn.maybe_send_server_notice_to_user(user_id) + await sn.maybe_send_server_notice_to_user(user_id) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 3e53c8568a..efcdd2100b 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) - @defer.inlineCallbacks - def is_server_admin(self, user): + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver. Args: @@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self.db.simple_select_one_onecol( + res = await self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index be665262c6..8aa56f1496 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), "Frank Jr.", ) # Set displayname again - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank" + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) ) self.assertEquals( @@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase): ) # Setting displayname a second time is forbidden - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) yield self.assertFailure(d, SynapseError) @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + ) ) yield self.assertFailure(d, AuthError) @@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) self.assertEquals( @@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase): ) # Set avatar again - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/me.png", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) ) self.assertEquals( @@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase): ) # Set avatar a second time is forbidden - d = self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + d = defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index f1dc51d6c9..1b7935cef2 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=False) + self.store.is_real_user = Mock(return_value=defer.succeed(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=1) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(1)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=2) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(2)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + async def get_or_create_user( + self, requester, localpart, displayname, password_hash=None + ): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth().check_auth_blocking() need_register = True try: - yield self.handler.check_username(localpart) + await self.handler.check_username(localpart) except SynapseError as e: if e.errcode == Codes.USER_IN_USE: need_register = False @@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): token = self.macaroon_generator.generate_access_token(user_id) if need_register: - yield self.handler.register_with_store( + await self.handler.register_with_store( user_id=user_id, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: - yield defer.ensureDeferred( - self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - ) + await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.hs.get_profile_handler().set_displayname( + await self.hs.get_profile_handler().set_displayname( user, requester, displayname, by_admin=True ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 93eb053b8c..987addad9b 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -55,25 +55,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) ) - self._send_notice = self._rlsn._server_notices_manager.send_notice - self._rlsn._server_notices_manager.send_notice = Mock() - self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) - + self._rlsn._server_notices_manager.send_notice = Mock( + return_value=defer.succeed(Mock()) + ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" - # self.server_notices_mxid = "@server:test" - # self.server_notices_mxid_display_name = None - # self.server_notices_mxid_avatar_url = None - # self.server_notices_room_name = "Server Notices" - self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - returnValue="" + return_value=defer.succeed("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock() + self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.get_tags_for_room = Mock(return_value={}) self.hs.config.admin_contact = "mailto:user@test.com" @@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() @@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) mock_event = Mock( @@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( - side_effect=ResourceLimitError(403, "foo") + return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) @@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self.assertTrue(self._send_notice.call_count == 0) + self.assertEqual(self._send_notice.call_count, 0) def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ self.hs.config.mau_limit_alerting = False + self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED - ) + ), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self.hs.config.mau_limit_alerting = False self._rlsn._auth.check_auth_blocking = Mock( + return_value=defer.succeed(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER - ) + ), ) + self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( return_value=defer.succeed((True, [])) ) @@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock(return_value=1000) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/test_federation.py b/tests/test_federation.py index 9b5cf562f3..f297de95f1 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = room_creator.create_room( - our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + room = ensureDeferred( + room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) ) self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] -- cgit 1.5.1 From d9b8d274949df7356e880a67d3aac1b25613ab1f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 7 May 2020 11:35:23 +0200 Subject: Add a configuration setting for the dummy event threshold (#7422) Add dummy_events_threshold which allows configuring the number of forward extremities a room needs for Synapse to send forward extremities in it. --- changelog.d/7422.feature | 1 + docs/sample_config.yaml | 12 ++++++++++++ synapse/config/server.py | 15 +++++++++++++++ synapse/handlers/message.py | 4 +++- 4 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7422.feature (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7422.feature b/changelog.d/7422.feature new file mode 100644 index 0000000000..d6d5bb2169 --- /dev/null +++ b/changelog.d/7422.feature @@ -0,0 +1 @@ +Add a configuration setting to tweak the threshold for dummy events. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fc970986c6..98ead7dc0e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -253,6 +253,18 @@ listeners: # bind_addresses: ['::1', '127.0.0.1'] # type: manhole +# Forward extremities can build up in a room due to networking delays between +# homeservers. Once this happens in a large room, calculation of the state of +# that room can become quite expensive. To mitigate this, once the number of +# forward extremities reaches a given threshold, Synapse will send an +# org.matrix.dummy_event event, which will reduce the forward extremities +# in the room. +# +# This setting defines the threshold (i.e. number of forward extremities in the +# room) at which dummy events are sent. The default value is 10. +# +#dummy_events_threshold: 5 + ## Homeserver blocking ## diff --git a/synapse/config/server.py b/synapse/config/server.py index c6d58effd4..6d88231843 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -505,6 +505,9 @@ class ServerConfig(Config): "cleanup_extremities_with_dummy_events", True ) + # The number of forward extremities in a room needed to send a dummy event. + self.dummy_events_threshold = config.get("dummy_events_threshold", 10) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) # Inhibits the /requestToken endpoints from returning an error that might leak @@ -823,6 +826,18 @@ class ServerConfig(Config): # bind_addresses: ['::1', '127.0.0.1'] # type: manhole + # Forward extremities can build up in a room due to networking delays between + # homeservers. Once this happens in a large room, calculation of the state of + # that room can become quite expensive. To mitigate this, once the number of + # forward extremities reaches a given threshold, Synapse will send an + # org.matrix.dummy_event event, which will reduce the forward extremities + # in the room. + # + # This setting defines the threshold (i.e. number of forward extremities in the + # room) at which dummy events are sent. The default value is 10. + # + #dummy_events_threshold: 5 + ## Homeserver blocking ## diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a324f09340..a622a600b4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -419,6 +419,8 @@ class EventCreationHandler(object): self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._dummy_events_threshold = hs.config.dummy_events_threshold + @defer.inlineCallbacks def create_event( self, @@ -1085,7 +1087,7 @@ class EventCreationHandler(object): """ self._expire_rooms_to_exclude_from_dummy_event_insertion() room_ids = await self.store.get_rooms_with_many_extremities( - min_count=10, + min_count=self._dummy_events_threshold, limit=5, room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), ) -- cgit 1.5.1 From 1124111a12c3ab35f8b68d9031695aec8b2c7c50 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 17:15:40 +0100 Subject: Allow censoring of events to happen on workers. (#7492) This is safe as we can now write to cache invalidation stream on workers, and is required for when we move event persistence off master. --- changelog.d/7492.misc | 1 + synapse/app/generic_worker.py | 2 ++ synapse/handlers/message.py | 2 -- synapse/storage/data_stores/main/censor_events.py | 7 +------ 4 files changed, 4 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7492.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7492.misc b/changelog.d/7492.misc new file mode 100644 index 0000000000..5ad31819d6 --- /dev/null +++ b/changelog.d/7492.misc @@ -0,0 +1 @@ +Allow censoring of events to happen on workers. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 667ad20428..bccb1140b2 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer +from synapse.storage.data_stores.main.censor_events import CensorEventsStore from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, @@ -442,6 +443,7 @@ class GenericWorkerSlavedStore( SlavedGroupServerStore, SlavedAccountDataStore, SlavedPusherStore, + CensorEventsStore, SlavedEventStore, SlavedKeyStore, RoomStore, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a622a600b4..0242521cc6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -72,7 +72,6 @@ class MessageHandler(object): self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages - self._is_worker_app = bool(hs.config.worker_app) # The scheduled call to self._expire_event. None if no call is currently # scheduled. @@ -260,7 +259,6 @@ class MessageHandler(object): Args: event (EventBase): The event to schedule the expiry of. """ - assert not self._is_worker_app expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) if not isinstance(expiry_ts, int) or event.is_state(): diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py index 49683b4d5a..2d48261724 100644 --- a/synapse/storage/data_stores/main/censor_events.py +++ b/synapse/storage/data_stores/main/censor_events.py @@ -33,15 +33,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class CensorEventsStore(CacheInvalidationWorkerStore, EventsWorkerStore, SQLBaseStore): +class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - # This should only exist on master for now - assert ( - hs.config.worker.worker_app is None - ), "Can only instantiate CensorEventsStore on master" - def _censor_redactions(): return run_as_background_process( "_censor_redactions", self._censor_redactions -- cgit 1.5.1 From 250f3eb991f129bbd12e4fce58a9d4124aabd41e Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Tue, 19 May 2020 04:31:25 -0500 Subject: Omit displayname or avatar_url if they aren't set instead of returning null (#7497) Per https://github.com/matrix-org/matrix-doc/issues/1436#issuecomment-410089470 they should be omitted instead of returning null or "". They aren't marked as required in the spec. Fixes https://github.com/matrix-org/synapse/issues/7333 Signed-off-by: Aaron Raimist --- changelog.d/7497.bugfix | 1 + synapse/handlers/message.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7497.bugfix (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7497.bugfix b/changelog.d/7497.bugfix new file mode 100644 index 0000000000..3c7920cb10 --- /dev/null +++ b/changelog.d/7497.bugfix @@ -0,0 +1 @@ +When sending `m.room.member` events, omit `displayname` and `avatar_url` if they aren't set instead of setting them to `null`. Contributed by Aaron Raimist. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0242521cc6..8f362896a2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -484,9 +484,13 @@ class EventCreationHandler(object): try: if "displayname" not in content: - content["displayname"] = yield profile.get_displayname(target) + displayname = yield profile.get_displayname(target) + if displayname is not None: + content["displayname"] = displayname if "avatar_url" not in content: - content["avatar_url"] = yield profile.get_avatar_url(target) + avatar_url = yield profile.get_avatar_url(target) + if avatar_url is not None: + content["avatar_url"] = avatar_url except Exception as e: logger.info( "Failed to get profile information for %r: %s", target, e -- cgit 1.5.1 From 1531b214fc57714c14046a8f66c7b5fe5ec5dcdd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 14:21:54 +0100 Subject: Add ability to wait for replication streams (#7542) The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room). Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on. People probably want to look at this commit by commit. --- changelog.d/7542.misc | 1 + synapse/handlers/federation.py | 33 ++++++--- synapse/handlers/message.py | 36 ++++++--- synapse/handlers/room.py | 65 +++++++++++----- synapse/handlers/room_member.py | 65 ++++++++++------ synapse/handlers/room_member_worker.py | 11 +-- synapse/replication/http/federation.py | 13 +++- synapse/replication/http/membership.py | 14 ++-- synapse/replication/http/send_event.py | 4 +- synapse/replication/http/streams.py | 5 +- synapse/replication/tcp/client.py | 90 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 10 ++- synapse/rest/client/v1/room.py | 20 +++-- synapse/rest/client/v2_alpha/relations.py | 2 +- synapse/server.py | 5 ++ synapse/server.pyi | 5 ++ synapse/server_notices/server_notices_manager.py | 6 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/roommember.py | 2 + tests/federation/test_complexity.py | 8 +- tests/handlers/test_typing.py | 5 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_event_metrics.py | 2 +- tests/test_federation.py | 4 +- 24 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 changelog.d/7542.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7542.misc b/changelog.d/7542.misc new file mode 100644 index 0000000000..7dd9b4823b --- /dev/null +++ b/changelog.d/7542.misc @@ -0,0 +1 @@ +Add ability to wait for replication streams. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bb03cc9add..e354c803db 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,6 +126,7 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._replication = hs.get_replication_data_handler() self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( hs @@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict - ) -> None: + ) -> Tuple[str, int]: """ Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. @@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler): room_id=room_id, room_version=room_version_obj, ) - await self._persist_auth_tree( + max_stream_id = await self._persist_auth_tree( origin, auth_chain, state, event, room_version_obj ) + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + "master", "events", max_stream_id + ) + # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = await self.store.get_room_predecessor(room_id) if not predecessor or not isinstance(predecessor.get("room_id"), str): - return + return event.event_id, max_stream_id old_room_id = predecessor["room_id"] logger.debug( "Found predecessor for %s during remote join: %s", room_id, old_room_id @@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler): ) logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] @@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler): async def do_remotely_reject_invite( self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict - ) -> EventBase: + ) -> Tuple[EventBase, int]: origin, event, room_version = await self._make_and_verify_event( target_hosts, room_id, user_id, "leave", content=content ) @@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(target_hosts, event) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify([(event, context)]) - return event + return event, stream_id async def _make_and_verify_event( self, @@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler): state: List[EventBase], event: EventBase, room_version: RoomVersion, - ) -> None: + ) -> int: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler): self, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> None: + ) -> int: """Persists events and tells the notifier/pushers about them, if necessary. @@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler): backfilling or not """ if self.config.worker_app: - await self._send_events_to_master( + result = await self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, ) + return result["max_stream_id"] else: max_stream_id = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled @@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler): for event, _ in event_and_contexts: await self._notify_persisted_event(event, max_stream_id) + return max_stream_id + async def _notify_persisted_event( self, event: EventBase, max_stream_id: int ) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f362896a2..f445e2aa2a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple from six import iteritems, itervalues, string_types @@ -42,6 +42,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -630,7 +631,9 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - async def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event( + self, requester, event, context, ratelimit=True + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -639,6 +642,9 @@ class EventCreationHandler(object): context (Context) the context of the event. ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. + + Return: + The stream_id of the persisted event. """ if event.type == EventTypes.Member: raise SynapseError( @@ -659,7 +665,7 @@ class EventCreationHandler(object): ) return prev_state - await self.handle_new_client_event( + return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -688,7 +694,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None - ): + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -711,10 +717,10 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - await self.send_nonmember_event( + stream_id = await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - return event + return event, stream_id @measure_func("create_new_client_event") @defer.inlineCallbacks @@ -774,7 +780,7 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -787,6 +793,9 @@ class EventCreationHandler(object): context (EventContext) ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event + + Return: + The stream_id of the persisted event. """ if event.is_state() and (event.type, event.state_key) == ( @@ -827,7 +836,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - await self.send_event_to_master( + result = await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -836,14 +845,17 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + stream_id = result["stream_id"] + event.internal_metadata.stream_ordering = stream_id success = True - return + return stream_id - await self.persist_and_notify_client_event( + stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True + return stream_id finally: if not success: # Ensure that we actually remove the entries in the push actions @@ -886,7 +898,7 @@ class EventCreationHandler(object): async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1076,6 +1088,8 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + return event_stream_id + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 13850ba672..2698a129ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,6 +22,7 @@ import logging import math import string from collections import OrderedDict +from typing import Tuple from six import iteritems, string_types @@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler): async def create_room( self, requester, config, ratelimit=True, creator_join_profile=None - ): + ) -> Tuple[dict, int]: """ Creates a new room. Args: @@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler): `avatar_url` and/or `displayname`. Returns: - Deferred[dict]: - a dict containing the keys `room_id` and, if an alias was - requested, `room_alias`. + First, a dict containing the keys `room_id` and, if an alias + was, requested, `room_alias`. Secondly, the stream_id of the + last persisted event. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. @@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - await self._send_events_for_new_room( + last_stream_id = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - await self.room_member_handler.update_membership( + _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - await self.hs.get_room_member_handler().do_3pid_invite( + last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - return result + return result, last_stream_id async def _send_events_for_new_room( self, @@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler): room_alias=None, power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, - ): + ) -> int: + """Sends the initial events into a new room. + + Returns: + The stream_id of the last event persisted. + """ + def create(etype, content, **kwargs): e = {"type": etype, "content": content} @@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype, content, **kwargs): + async def send(etype, content, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) + return last_stream_id config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - await send(etype=EventTypes.PowerLevels, content=pl_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=pl_content + ) else: power_level_content = { "users": {creator_id: 100}, @@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - await send(etype=EventTypes.PowerLevels, content=power_level_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=power_level_content + ) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - await send(etype=etype, state_key=state_key, content=content) + last_sent_stream_id = await send( + etype=etype, state_key=state_key, content=content + ) + + return last_sent_stream_id async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e51e1c32fe..691b6705b2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple from six.moves import http_client @@ -84,7 +84,7 @@ class RoomMemberHandler(object): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Try and join a room that this server is not in Args: @@ -104,7 +104,7 @@ class RoomMemberHandler(object): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -154,7 +154,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> EventBase: + ) -> Tuple[str, int]: user_id = target.to_string() if content is None: @@ -187,9 +187,10 @@ class RoomMemberHandler(object): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return duplicate + _, stream_id = await self.store.get_event_ordering(duplicate.event_id) + return duplicate.event_id, stream_id - await self.event_creation_handler.handle_new_client_event( + stream_id = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) @@ -213,7 +214,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) - return event + return event.event_id, stream_id async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id @@ -263,7 +264,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -294,7 +295,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: content_specified = bool(content) if content is None: content = {} @@ -398,7 +399,13 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - return old_state + _, stream_id = await self.store.get_event_ordering( + old_state.event_id + ) + return ( + old_state.event_id, + stream_id, + ) if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -705,7 +712,7 @@ class RoomMemberHandler(object): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -737,11 +744,11 @@ class RoomMemberHandler(object): ) if invitee: - await self.update_membership( + _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - await self._make_and_store_3pid_invite( + stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -752,6 +759,8 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) + return stream_id + async def _make_and_store_3pid_invite( self, requester: Requester, @@ -762,7 +771,7 @@ class RoomMemberHandler(object): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -817,7 +826,10 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -835,6 +847,7 @@ class RoomMemberHandler(object): ratelimit=False, txn_id=txn_id, ) + return stream_id async def _is_host_in_room( self, current_state_ids: Dict[Tuple[str, str], str] @@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> None: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) await self._user_joined_room(user, room_id) @@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: if too_complex is False: # We checked, and we're under the limit. - return + return event_id, stream_id # Check again, but with the local state events too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. - return + return event_id, stream_id # The room is too large. Leave. requester = types.create_requester(user, None, False, None) @@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) + return event_id, stream_id + async def _remote_reject_invite( self, requester: Requester, @@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = await fed_handler.do_remotely_reject_invite( + event, stream_id = await fed_handler.do_remotely_reject_invite( remote_room_hosts, room_id, target.to_string(), content=content, ) - return ret + return event.event_id, stream_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(target.to_string(), room_id) - return {} + stream_id = await self.store.locally_reject_invite( + target.to_string(), room_id + ) + return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 5c776cc0be..02e0c4103d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._user_joined_room(user, room_id) - return ret + return ret["event_id"], ret["stream_id"] async def _remote_reject_invite( self, @@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ - return await self._remote_reject_client( + ret = await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), content=content, ) + return ret["event_id"], ret["stream_id"] async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7e23b565b9..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 3577611fd7..050fd34562 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = await self.federation_handler.do_remotely_reject_invite( + event, stream_id = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.store.locally_reject_invite(user_id, room_id) + event_id = None - return 200, ret + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index b74b088ff4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - await self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index b705a8e16c..bde97eef32 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): super().__init__(hs) self._instance_name = hs.get_instance_name() - - # We pull the streams from the replication handler (if we try and make - # them ourselves we end up in an import loop). - self.streams = hs.get_tcp_replication().get_streams() + self.streams = hs.get_replication_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 28826302f5..508ad1b720 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,19 +14,23 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,6 +39,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. @@ -92,6 +100,16 @@ class ReplicationDataHandler: self.store = hs.get_datastore() self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -131,8 +149,76 @@ class ReplicationDataHandler: await self.pusher_pool.on_new_notifications(token, token) + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any futher entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. + """ + + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) + + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. + heapq.heappush(waiting_list, (position, deferred)) + + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7d40001988..0a13e1ed34 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self._replication = hs.get_replication_data_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet): message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = await self._room_creation_handler.create_room( + info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet): # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position("master", "events", stream_id) + users = await self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6b5830cc3f..105e0cf4d2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request): requester = await self.auth.get_user_by_req(request) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = await self.room_member_handler.update_membership( + event_id, _ = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + event_id = event.event_id ret = {} # type: dict - if event: - set_tag("event_id", event.event_id) - ret = {"event_id": event.event_id} + if event_id: + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 63f07b63da..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) diff --git a/synapse/server.py b/synapse/server.py index c530f1aa1a..ca2deb49bb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -210,6 +211,7 @@ class HomeServer(object): "storage", "replication_streamer", "replication_data_handler", + "replication_streams", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -583,6 +585,9 @@ class HomeServer(object): def build_replication_data_handler(self): return ReplicationDataHandler(self) + def build_replication_streams(self): + return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9e7fad7e6e..fe8024d2d4 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +from typing import Dict + import twisted.internet import synapse.api.auth @@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.replication.tcp.streams import Stream class HomeServer(object): @property @@ -136,3 +139,5 @@ class HomeServer(object): pass def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: pass + def get_replication_streams(self) -> Dict[str, Stream]: + pass diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 999c621b92..bf2454c01c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -83,10 +83,10 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = await self._event_creation_handler.create_and_send_nonmember_event( + event, _ = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - return res + return event @cached() async def get_or_create_notice_room_for_user(self, user_id): @@ -143,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9130b74eb5..b880a71782 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore): async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): + def get_event_ordering(self, event_id): res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 1e9c850152..7c5ca81ae0 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) + return stream_ordering + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 94980733c4..0c9987be54 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) d = handler._remote_join( None, @@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) # Artificially raise the complexity self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 51e2b37218..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): reactor.pump((1000,)) hs = self.setup_test_homeserver( - notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, + replication_streams={}, ) hs.datastores = datastores diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 0e04b2cf92..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): @@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b7fd36d3..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None diff --git a/tests/test_federation.py b/tests/test_federation.py index 13ff14863e..c5099dd039 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = ensureDeferred( + room_deferred = ensureDeferred( room_creator.create_room( our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False ) ) self.reactor.advance(0.1) - self.room_id = self.successResultOf(room)["room_id"] + self.room_id = self.successResultOf(room_deferred)[0]["room_id"] self.store = self.homeserver.get_datastore() -- cgit 1.5.1 From e5c67d04dbe5ed45d659e826a5dfcd5044a4e374 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 16:11:35 +0100 Subject: Add option to move event persistence off master (#7517) --- changelog.d/7517.feature | 1 + scripts/synapse_port_db | 3 + synapse/app/generic_worker.py | 53 +++++++++- synapse/config/logger.py | 1 + synapse/config/workers.py | 30 +++++- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 12 ++- synapse/handlers/presence.py | 6 ++ synapse/handlers/room.py | 7 ++ synapse/handlers/room_member.py | 39 ++++++-- synapse/replication/http/__init__.py | 4 +- synapse/replication/http/_base.py | 3 + synapse/replication/http/membership.py | 40 +++++++- synapse/replication/http/presence.py | 116 ++++++++++++++++++++++ synapse/replication/tcp/handler.py | 10 ++ synapse/rest/admin/rooms.py | 11 +- synapse/storage/data_stores/__init__.py | 6 +- synapse/storage/data_stores/main/devices.py | 30 ++++-- synapse/storage/data_stores/main/events.py | 34 ++++++- synapse/storage/data_stores/main/events_worker.py | 2 +- synapse/storage/data_stores/main/roommember.py | 25 ----- synapse/storage/persist_events.py | 6 ++ 22 files changed, 382 insertions(+), 73 deletions(-) create mode 100644 changelog.d/7517.feature create mode 100644 synapse/replication/http/presence.py (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7517.feature b/changelog.d/7517.feature new file mode 100644 index 0000000000..6062f0061c --- /dev/null +++ b/changelog.d/7517.feature @@ -0,0 +1 @@ +Add option to move event persistence off master. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index acd9ac4b75..9a0fbc61d8 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -196,6 +196,9 @@ class MockHomeserver: def get_reactor(self): return reactor + def get_instance_name(self): + return "master" + class Porter(object): def __init__(self, **kwargs): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index a37520000a..2906b93f6a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -39,7 +39,11 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.presence import BasePresenceHandler, get_interested_parties +from synapse.handlers.presence import ( + BasePresenceHandler, + PresenceState, + get_interested_parties, +) from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite @@ -47,6 +51,10 @@ from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource +from synapse.replication.http.presence import ( + ReplicationBumpPresenceActiveTime, + ReplicationPresenceSetState, +) from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore @@ -247,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler): # but we haven't notified the master of that yet self.users_going_offline = {} + self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) + self._set_state_client = ReplicationPresenceSetState.make_client(hs) + self._send_stop_syncing_loop = self.clock.looping_call( self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) @@ -304,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler): self.users_going_offline.pop(user_id, None) self.send_user_sync(user_id, False, last_sync_ms) - def set_state(self, user, state, ignore_status_msg=False): - # TODO Hows this supposed to work? - return defer.succeed(None) - async def user_syncing( self, user_id: str, affect_presence: bool ) -> ContextManager[None]: @@ -386,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler): if count > 0 ] + async def set_state(self, target_user, state, ignore_status_msg=False): + """Set the presence state of the user. + """ + presence = state["presence"] + + valid_presence = ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ) + if presence not in valid_presence: + raise SynapseError(400, "Invalid presence state") + + user_id = target_user.to_string() + + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + await self._set_state_client( + user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + ) + + async def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. + """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + user_id = user.to_string() + await self._bump_active_client(user_id=user_id) + class GenericWorkerTyping(object): def __init__(self, hs): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index a25c70e928..49f6c32beb 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -257,5 +257,6 @@ def setup_logging( logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) + logging.info("Instance name: %s", hs.get_instance_name()) return logger diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c80c338584..ed06b91a54 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -15,7 +15,7 @@ import attr -from ._base import Config +from ._base import Config, ConfigError @attr.s @@ -27,6 +27,17 @@ class InstanceLocationConfig: port = attr.ib(type=int) +@attr.s +class WriterLocations: + """Specifies the instances that write various streams. + + Attributes: + events: The instance that writes to the event and backfill streams. + """ + + events = attr.ib(default="master", type=str) + + class WorkerConfig(Config): """The workers are processes run separately to the main synapse process. They have their own pid_file and listener configuration. They use the @@ -83,11 +94,26 @@ class WorkerConfig(Config): bind_addresses.append("") # A map from instance name to host/port of their HTTP replication endpoint. - instance_map = config.get("instance_map", {}) or {} + instance_map = config.get("instance_map") or {} self.instance_map = { name: InstanceLocationConfig(**c) for name, c in instance_map.items() } + # Map from type of streams to source, c.f. WriterLocations. + writers = config.get("stream_writers") or {} + self.writers = WriterLocations(**writers) + + # Check that the configured writer for events also appears in + # `instance_map`. + if ( + self.writers.events != "master" + and self.writers.events not in self.instance_map + ): + raise ConfigError( + "Instance %r is configured to write events but does not appear in `instance_map` config." + % (self.writers.events,) + ) + def read_arguments(self, args): # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e354c803db..75ec90d267 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,11 +126,10 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() - self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( - hs - ) + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( hs ) @@ -1243,6 +1242,10 @@ class FederationHandler(BaseHandler): content: The event content to use for the join event. """ + # TODO: We should be able to call this on workers, but the upgrading of + # room stuff after join currently doesn't work on workers. + assert self.config.worker.worker_app is None + logger.debug("Joining %s to %s", joinee, room_id) origin, event, room_version_obj = await self._make_and_verify_event( @@ -1314,7 +1317,7 @@ class FederationHandler(BaseHandler): # # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - "master", "events", max_stream_id + self.config.worker.writers.events, "events", max_stream_id ) # Check whether this room is the result of an upgrade of a room we already know @@ -2854,8 +2857,9 @@ class FederationHandler(BaseHandler): backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker_app: - result = await self._send_events_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self._send_events( + instance_name=self.config.worker.writers.events, store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f445e2aa2a..ea25f0515a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -366,10 +366,11 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types - self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) + self.send_event = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users self.base_handler = BaseHandler(hs) @@ -835,8 +836,9 @@ class EventCreationHandler(object): success = False try: # If we're a worker we need to hit out to the master. - if self.config.worker_app: - result = await self.send_event_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self.send_event( + instance_name=self.config.worker.writers.events, event_id=event.event_id, store=self.store, requester=requester, @@ -902,9 +904,9 @@ class EventCreationHandler(object): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. - This should only be run on master. + This should only be run on the instance in charge of persisting events. """ - assert not self.config.worker_app + assert self.config.worker.writers.events == self._instance_name if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9ea11c0754..3594f3b00f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC): ) -> None: """Set the presence state of the user. """ + @abc.abstractmethod + async def bump_presence_active_time(self, user: UserID): + """We've seen the user do something that indicates they're interacting + with the app. + """ + class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "synapse.server.HomeServer"): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2698a129ca..61db3ccc43 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -89,6 +89,8 @@ class RoomCreationHandler(BaseHandler): self.room_member_handler = hs.get_room_member_handler() self.config = hs.config + self._replication = hs.get_replication_data_handler() + # linearizer to stop two upgrades happening at once self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") @@ -752,6 +754,11 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() + # Always wait for room creation to progate before returning + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", last_stream_id + ) + return result, last_stream_id async def _send_events_for_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 691b6705b2..0f7af982f0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,6 +26,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.replication.http.membership import ( + ReplicationLocallyRejectInviteRestServlet, +) from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -44,11 +47,6 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -71,6 +69,17 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._event_stream_writer_instance = hs.config.worker.writers.events + self._is_on_event_persistence_instance = ( + self._event_stream_writer_instance == hs.get_instance_name() + ) + if self._is_on_event_persistence_instance: + self.persist_event_storage = hs.get_storage().persistence + else: + self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client( + hs + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -121,6 +130,22 @@ class RoomMemberHandler(object): """ raise NotImplementedError() + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + if self._is_on_event_persistence_instance: + return await self.persist_event_storage.locally_reject_invite( + user_id, room_id + ) + else: + result = await self._locally_reject_client( + instance_name=self._event_stream_writer_instance, + user_id=user_id, + room_id=room_id, + ) + return result["stream_id"] + @abc.abstractmethod async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the @@ -1015,9 +1040,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.store.locally_reject_invite( - target.to_string(), room_id - ) + stream_id = await self.locally_reject_invite(target.to_string(), room_id) return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a909744e93..19b69e0e11 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -19,6 +19,7 @@ from synapse.replication.http import ( federation, login, membership, + presence, register, send_event, streams, @@ -35,10 +36,11 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) federation.register_servlets(hs, self) + presence.register_servlets(hs, self) + membership.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: - membership.register_servlets(hs, self) login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index c3136a4eb9..793cef6c26 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -142,6 +142,7 @@ class ReplicationEndpoint(object): """ clock = hs.get_clock() client = hs.get_simple_http_client() + local_instance_name = hs.get_instance_name() master_host = hs.config.worker_replication_host master_port = hs.config.worker_replication_http_port @@ -151,6 +152,8 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks def send_request(instance_name="master", **kwargs): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") if instance_name == "master": host = master_host port = master_port diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 050fd34562..a7174c4a8f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -14,12 +14,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + self.member_handler = hs.get_room_member_handler() @staticmethod def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): @@ -149,12 +154,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.store.locally_reject_invite(user_id, room_id) + stream_id = await self.member_handler.locally_reject_invite( + user_id, room_id + ) event_id = None return 200, {"event_id": event_id, "stream_id": stream_id} +class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint): + """Rejects the invite for the user and room locally. + + Request format: + + POST /_synapse/replication/locally_reject_invite/:room_id/:user_id + + {} + """ + + NAME = "locally_reject_invite" + PATH_ARGS = ("room_id", "user_id") + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.member_handler = hs.get_room_member_handler() + + @staticmethod + def _serialize_payload(room_id, user_id): + return {} + + async def _handle_request(self, request, room_id, user_id): + logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id) + + stream_id = await self.member_handler.locally_reject_invite(user_id, room_id) + + return 200, {"stream_id": stream_id} + + class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): """Notifies that a user has joined or left the room @@ -208,3 +245,4 @@ def register_servlets(hs, http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) + ReplicationLocallyRejectInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py new file mode 100644 index 0000000000..ea1b33331b --- /dev/null +++ b/synapse/replication/http/presence.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): + """We've seen the user do something that indicates they're interacting + with the app. + + The POST looks like: + + POST /_synapse/replication/bump_presence_active_time/ + + 200 OK + + {} + """ + + NAME = "bump_presence_active_time" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + await self._presence_handler.bump_presence_active_time( + UserID.from_string(user_id) + ) + + return ( + 200, + {}, + ) + + +class ReplicationPresenceSetState(ReplicationEndpoint): + """Set the presence state for a user. + + The POST looks like: + + POST /_synapse/replication/presence_set_state/ + + { + "state": { ... }, + "ignore_status_msg": false, + } + + 200 OK + + {} + """ + + NAME = "presence_set_state" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id, state, ignore_status_msg=False): + return { + "state": state, + "ignore_status_msg": ignore_status_msg, + } + + async def _handle_request(self, request, user_id): + content = parse_json_object_from_request(request) + + await self._presence_handler.set_state( + UserID.from_string(user_id), content["state"], content["ignore_status_msg"] + ) + + return ( + 200, + {}, + ) + + +def register_servlets(hs, http_server): + ReplicationBumpPresenceActiveTime(hs).register(http_server) + ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index acfa66a7a8..03300e5336 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -38,7 +38,9 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + BackfillStream, CachesStream, + EventsStream, FederationStream, Stream, ) @@ -87,6 +89,14 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) continue + if isinstance(stream, (EventsStream, BackfillStream)): + # Only add EventStream and BackfillStream as a source on the + # instance in charge of event persistence. + if hs.config.worker.writers.events == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 0a13e1ed34..8173baef8f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -100,7 +100,9 @@ class ShutdownRoomRestServlet(RestServlet): # we try and auto join below. # # TODO: Currently the events stream is written to from master - await self._replication.wait_for_stream_position("master", "events", stream_id) + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) users = await self.state.get_current_users_in_room(room_id) kicked_users = [] @@ -113,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet): try: target_requester = create_requester(user_id) - await self.room_member_handler.update_membership( + _, stream_id = await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -123,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet): require_consent=False, ) + # Wait for leave to come in over replication before trying to forget. + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + await self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.update_membership( diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 791961b296..599ee470d4 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -66,9 +66,9 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) - # If we're on a process that can persist events (currently - # master), also instantiate a `PersistEventsStore` - if hs.config.worker.worker_app is None: + # If we're on a process that can persist events also + # instantiate a `PersistEventsStore` + if hs.config.worker.writers.events == hs.get_instance_name(): self.persist_events = PersistEventsStore( hs, database, self.main ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 0e8378714a..417ac8dc7c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -689,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + + def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + self.db.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) + ) + + return self.db.runInteraction( + "mark_remote_user_device_list_as_unsubscribed", + _mark_remote_user_device_list_as_unsubscribed_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): @@ -969,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. - """ - yield self.db.simple_delete( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - desc="mark_remote_user_device_list_as_unsubscribed", - ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry( self, user_id, device_id, content, stream_id ): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index a97f8b3934..a6572571b4 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -138,10 +138,10 @@ class PersistEventsStore: self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator - # This should only exist on master for now + # This should only exist on instances that are configured to write assert ( - hs.config.worker.worker_app is None - ), "Can only instantiate PersistEventsStore on master" + hs.config.worker.writers.events == hs.get_instance_name() + ), "Can only instantiate EventsStore on master" @_retry_on_integrity_error @defer.inlineCallbacks @@ -1590,3 +1590,31 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + # We also clear this entry from `local_current_membership`. + # Ideally we'd point to a leave event, but we don't have one, so + # nevermind. + self.db.simple_delete_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + ) + + with self._stream_id_gen.get_next() as stream_ordering: + await self.db.runInteraction("locally_reject_invite", f, stream_ordering) + + return stream_ordering diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index b880a71782..213d69100a 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -76,7 +76,7 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker_app is None: + if hs.config.worker.writers.events == hs.get_instance_name(): # We are the process in charge of generating stream ids for events, # so instantiate ID generators based on the database self._stream_id_gen = StreamIdGenerator( diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 7c5ca81ae0..137ebac833 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1046,31 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - # We also clear this entry from `local_current_membership`. - # Ideally we'd point to a leave event, but we don't have one, so - # nevermind. - self.db.simple_delete_txn( - txn, - table="local_current_membership", - keyvalues={"room_id": room_id, "user_id": user_id}, - ) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) - - return stream_ordering - def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 12e1ffb9a2..f159400a87 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -786,3 +786,9 @@ class EventsPersistenceStorage(object): for user_id in left_users: await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + return await self.persist_events_store.locally_reject_invite(user_id, room_id) -- cgit 1.5.1 From f4269694ce51a04761e43fc5897d6f8fe0ea18cc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 22 May 2020 21:47:07 +0100 Subject: Optimise some references to hs.config (#7546) These are surprisingly expensive, and we only really need to do them at startup. --- changelog.d/7546.misc | 1 + synapse/handlers/message.py | 8 +- .../resource_limits_server_notices.py | 15 ++- .../data_stores/main/monthly_active_users.py | 28 ++++-- .../test_resource_limits_server_notices.py | 60 ++++++----- tests/storage/test_client_ips.py | 13 +-- tests/storage/test_monthly_active_users.py | 110 +++++++++++---------- tests/test_mau.py | 63 ++++++------ 8 files changed, 162 insertions(+), 136 deletions(-) create mode 100644 changelog.d/7546.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7546.misc b/changelog.d/7546.misc new file mode 100644 index 0000000000..0a9704789b --- /dev/null +++ b/changelog.d/7546.misc @@ -0,0 +1 @@ +Optimise some references to `hs.config`. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ea25f0515a..681f92cafd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -366,7 +366,9 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._instance_name = hs.get_instance_name() + self._is_event_writer = ( + self.config.worker.writers.events == hs.get_instance_name() + ) self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -836,7 +838,7 @@ class EventCreationHandler(object): success = False try: # If we're a worker we need to hit out to the master. - if self.config.worker.writers.events != self._instance_name: + if not self._is_event_writer: result = await self.send_event( instance_name=self.config.worker.writers.events, event_id=event.event_id, @@ -906,7 +908,7 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self.config.worker.writers.events == self._instance_name + assert self._is_event_writer if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index d97166351e..73f2cedb5c 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -48,6 +48,12 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() + self._enabled = ( + hs.config.limit_usage_by_mau + and self._server_notices_manager.is_enabled() + and not hs.config.hs_disabled + ) + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. @@ -61,14 +67,7 @@ class ResourceLimitsServerNotices(object): Returns: Deferred """ - if self._config.hs_disabled is True: - return - - if self._config.limit_usage_by_mau is False: - return - - if not self._server_notices_manager.is_enabled(): - # Don't try and send server notices unless they've been enabled + if not self._enabled: return timestamp = await self._store.user_last_seen_monthly_active(user_id) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 925bc5691b..a42c52f323 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -122,6 +122,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: Database, db_conn, hs): super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_stats_only = hs.config.mau_stats_only + self._max_mau_value = hs.config.max_mau_value + # Do not add more reserved users than the total allowable number # cur = LoggingTransaction( self.db.new_transaction( @@ -130,7 +134,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): [], [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): @@ -142,6 +146,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): threepids (list[dict]): List of threepid dicts to reserve """ + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + for tp in threepids: user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) @@ -191,8 +204,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn.execute(sql, query_args) - max_mau_value = self.hs.config.max_mau_value - if self.hs.config.limit_usage_by_mau: + if self._limit_usage_by_mau: # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. # Note it is not possible to write this query using OFFSET due to @@ -210,13 +222,13 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): LIMIT ? ) """ - txn.execute(sql, (max_mau_value,)) + txn.execute(sql, ((self._max_mau_value),)) # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # when len(reserved_users) == 0. Works fine on sqlite. else: # Must be >= 0 for postgres num_of_non_reserved_users_to_remove = max( - max_mau_value - len(reserved_users), 0 + self._max_mau_value - len(reserved_users), 0 ) # It is important to filter reserved users twice to guard @@ -335,7 +347,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): Args: user_id(str): the user_id to query """ - if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only: + if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group is_guest = yield self.is_guest(user_id) if is_guest: @@ -356,11 +368,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # In the case where mau_stats_only is True and limit_usage_by_mau is # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. - if not self.hs.config.limit_usage_by_mau: + if not self._limit_usage_by_mau: yield self.upsert_monthly_active_user(user_id) else: count = yield self.get_monthly_active_count() - if count < self.hs.config.max_mau_value: + if count < self._max_mau_value: yield self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: yield self.upsert_monthly_active_user(user_id) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 406f29a7c0..99908edba3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,20 +27,33 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - hs_config = self.default_config() - hs_config["server_notices"] = { - "system_mxid_localpart": "server", - "system_mxid_display_name": "test display name", - "system_mxid_avatar_url": None, - "room_name": "Server Notices", - } + def default_config(self): + config = default_config("test") + + config.update( + { + "admin_contact": "mailto:user@test.com", + "limit_usage_by_mau": True, + "server_notices": { + "system_mxid_localpart": "server", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + }, + } + ) + + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) - hs = self.setup_test_homeserver(config=hs_config) - return hs + return config def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() @@ -60,7 +73,6 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self._send_notice = self._rlsn._server_notices_manager.send_notice - self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( @@ -68,21 +80,17 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) - self.hs.config.admin_contact = "mailto:user@test.com" - - def test_maybe_send_server_notice_to_user_flag_off(self): - """Tests cases where the flags indicate nothing to do""" - # test hs disabled case - self.hs.config.hs_disabled = True + @override_config({"hs_disabled": True}) + def test_maybe_send_server_notice_disabled_hs(self): + """If the HS is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self._send_notice.assert_not_called() - # Test when mau limiting disabled - self.hs.config.hs_disabled = False - self.hs.config.limit_usage_by_mau = False - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + @override_config({"limit_usage_by_mau": False}) + def test_maybe_send_server_notice_to_user_flag_off(self): + """If mau limiting is disabled, we should not send notices""" + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): @@ -153,13 +161,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() + @override_config({"mau_limit_alerting": False}) def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): """ Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self.hs.config.mau_limit_alerting = False - self._rlsn._auth.check_auth_blocking = Mock( return_value=defer.succeed(None), side_effect=ResourceLimitError( @@ -170,12 +177,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.assertEqual(self._send_notice.call_count, 0) + @override_config({"mau_limit_alerting": False}) def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self.hs.config.mau_limit_alerting = False - self._rlsn._auth.check_auth_blocking = Mock( return_value=defer.succeed(None), side_effect=ResourceLimitError( @@ -187,12 +193,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): # Would be better to check contents, but 2 calls == set blocking event self.assertEqual(self._send_notice.call_count, 2) + @override_config({"mau_limit_alerting": False}) def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): """ When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self.hs.config.mau_limit_alerting = False self._rlsn._auth.check_auth_blocking = Mock( return_value=defer.succeed(None), side_effect=ResourceLimitError( diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index bf674dd184..3b483bc7f0 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): @@ -137,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) def test_disabled_monthly_active_user(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -149,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_full(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 lots_of_users = 100 user_id = "@user:server" @@ -166,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -184,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_updating_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index bc53bf0951..447fcb3a1c 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,94 +19,106 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 +def gen_3pids(count): + """Generate `count` threepids as a list.""" + return [ + {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count) + ] + + class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") + + config.update({"limit_usage_by_mau": True, "max_mau_value": 50}) - hs = self.setup_test_homeserver() - self.store = hs.get_datastore() - hs.config.limit_usage_by_mau = True - hs.config.max_mau_value = 50 + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) + return config + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() # Advance the clock a bit reactor.advance(FORTY_DAYS) - return hs - + @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - self.hs.config.max_mau_value = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + + # register three users, of which two have reserved 3pids, and a third + # which is a support user. user1 = "@user1:server" - user1_email = "user1@matrix.org" + user1_email = threepids[0]["address"] user2 = "@user2:server" - user2_email = "user2@matrix.org" + user2_email = threepids[1]["address"] user3 = "@user3:server" - user3_email = "user3@matrix.org" - threepids = [ - {"medium": "email", "address": user1_email}, - {"medium": "email", "address": user2_email}, - {"medium": "email", "address": user3_email}, - ] - self.hs.config.mau_limits_reserved_threepids = threepids - # -1 because user3 is a support user and does not count - user_num = len(threepids) - 1 - - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.store.register_user( - user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT - ) + self.store.register_user(user_id=user1) + self.store.register_user(user_id=user2) + self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) self.pump() now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) + # XXX why are we doing this here? this function is only run at startup + # so it is odd to re-run it here. self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() - active_count = self.store.get_monthly_active_count() + # the number of users we expect will be counted against the mau limit + # -1 because user3 is a support user and does not count + user_num = len(threepids) - 1 - # Test total counts, ensure user3 (support user) is not counted - self.assertEquals(self.get_success(active_count), user_num) + # Check the number of active users. Ensure user3 (support user) is not counted + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEquals(active_count, user_num) - # Test user is marked as active + # Test each of the registered users is marked as active timestamp = self.store.user_last_seen_monthly_active(user1) self.assertTrue(self.get_success(timestamp)) timestamp = self.store.user_last_seen_monthly_active(user2) self.assertTrue(self.get_success(timestamp)) - # Test that users are never removed from the db. + # Test that users with reserved 3pids are not removed from the MAU table + # XXX some of this is redundant. poking things into the config shouldn't + # work, and in any case it's not obvious what we expect to happen when + # we advance the reactor. self.hs.config.max_mau_value = 0 - self.reactor.advance(FORTY_DAYS) self.hs.config.max_mau_value = 5 - self.store.reap_monthly_active_users() self.pump() active_count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(active_count), user_num) - # Test that regular users are removed from the db + # Add some more users and check they are counted as active ru_count = 2 self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru2:server") self.pump() - active_count = self.store.get_monthly_active_count() self.assertEqual(self.get_success(active_count), user_num + ru_count) - self.hs.config.max_mau_value = user_num + + # now run the reaper and check that the number of active users is reduced + # to max_mau_value self.store.reap_monthly_active_users() self.pump() active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + self.assertEquals(self.get_success(active_count), 3) def test_can_insert_and_count_mau(self): count = self.store.get_monthly_active_count() @@ -136,8 +148,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): result = self.store.user_last_seen_monthly_active(user_id3) self.assertNotEqual(self.get_success(result), 0) + @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): - self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): self.store.upsert_monthly_active_user("@user%d:server" % i) @@ -158,19 +170,19 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(count), 0) + # Note that below says mau_limit (no s), this is the name of the config + # value, although it gets stored on the config object as mau_limits. + @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) def test_reap_monthly_active_users_reserved_users(self): """ Tests that reaping correctly handles reaping where reserved users are present""" - - self.hs.config.max_mau_value = 5 - initial_users = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + initial_users = len(threepids) reserved_user_number = initial_users - 1 - threepids = [] for i in range(initial_users): user = "@user%d:server" % i - email = "user%d@example.com" % i + email = "user%d@matrix.org" % i self.get_success(self.store.upsert_monthly_active_user(user)) - threepids.append({"medium": "email", "address": email}) # Need to ensure that the most recent entries in the # monthly_active_users table are reserved now = int(self.hs.get_clock().time_msec()) @@ -182,7 +194,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - self.hs.config.mau_limits_reserved_threepids = threepids self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) @@ -279,11 +290,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.assertEqual(self.get_success(count), 0) + # Note that the max_mau_value setting should not matter. + @override_config( + {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} + ) def test_track_monthly_users_without_cap(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - self.hs.config.max_mau_value = 1 # should not matter - count = self.store.get_monthly_active_count() self.assertEqual(0, self.get_success(count)) @@ -294,9 +305,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEqual(2, self.get_success(count)) + @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = False self.store.upsert_monthly_active_user = Mock() self.store.populate_monthly_active_users("@user:sever") diff --git a/tests/test_mau.py b/tests/test_mau.py index eb159e3ba5..8a97f0998d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,47 +17,44 @@ import json -from mock import Mock - -from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.rest.client.v2_alpha import register, sync from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): servlets = [register.register_servlets, sync.register_servlets] - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") - self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() + config.update( + { + "registrations_require_3pid": [], + "limit_usage_by_mau": True, + "max_mau_value": 2, + "mau_trial_days": 0, + "server_notices": { + "system_mxid_localpart": "server", + "room_name": "Test Server Notice Room", + }, + } ) - self.store = self.hs.get_datastore() - - self.hs.config.registrations_require_3pid = [] - self.hs.config.enable_registration_captcha = False - self.hs.config.recaptcha_public_key = [] - - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 2 - self.hs.config.server_notices_mxid = "@server:red" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - self.hs.config.mau_trial_days = 0 + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) - # AuthBlocking reads config options during hs creation. Recreate the - # hs' copy of AuthBlocking after we've updated config values above - self.auth_blocking = AuthBlocking(self.hs) - self.hs.get_auth()._auth_blocking = self.auth_blocking + return config - return self.hs + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated @@ -66,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase): token2 = self.create_user("kermit2") self.do_sync_for_user(token2) + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + # We've created and activated two users, we shouldn't be able to # register new users with self.assertRaises(SynapseError) as cm: @@ -93,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase): token3 = self.create_user("kermit3") self.do_sync_for_user(token3) + @override_config({"mau_trial_days": 1}) def test_trial_delay(self): - self.hs.config.mau_trial_days = 1 - # We should be able to register more than the limit initially token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -127,8 +126,8 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + @override_config({"mau_trial_days": 1}) def test_trial_users_cant_come_back(self): - self.auth_blocking._mau_trial_days = 1 self.hs.config.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -176,11 +175,11 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + @override_config( + # max_mau_value should not matter + {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True} + ) def test_tracked_but_not_limited(self): - self.auth_blocking._max_mau_value = 1 # should not matter - self.auth_blocking._limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - # Simply being able to create 2 users indicates that the # limit was not reached. token1 = self.create_user("kermit1") -- cgit 1.5.1 From f4e6495b5d3267976f34088fa7459b388b801eb6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 5 Jun 2020 10:47:20 +0100 Subject: Performance improvements and refactor of Ratelimiter (#7595) While working on https://github.com/matrix-org/synapse/issues/5665 I found myself digging into the `Ratelimiter` class and seeing that it was both: * Rather undocumented, and * causing a *lot* of config checks This PR attempts to refactor and comment the `Ratelimiter` class, as well as encourage config file accesses to only be done at instantiation. Best to be reviewed commit-by-commit. --- changelog.d/7595.misc | 1 + synapse/api/ratelimiting.py | 153 +++++++++++++++++++++------- synapse/config/ratelimiting.py | 8 +- synapse/handlers/_base.py | 60 ++++++----- synapse/handlers/auth.py | 24 ++--- synapse/handlers/message.py | 1 - synapse/handlers/register.py | 9 +- synapse/rest/client/v1/login.py | 65 ++++-------- synapse/rest/client/v2_alpha/register.py | 16 +-- synapse/server.py | 17 ++-- synapse/util/ratelimitutils.py | 2 +- tests/api/test_ratelimiting.py | 96 +++++++++++++---- tests/handlers/test_profile.py | 6 +- tests/replication/slave/storage/_base.py | 9 +- tests/rest/client/v1/test_events.py | 8 +- tests/rest/client/v1/test_login.py | 49 +++++++-- tests/rest/client/v1/test_rooms.py | 9 +- tests/rest/client/v1/test_typing.py | 10 +- tests/rest/client/v2_alpha/test_register.py | 9 +- 19 files changed, 322 insertions(+), 230 deletions(-) create mode 100644 changelog.d/7595.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/7595.misc b/changelog.d/7595.misc new file mode 100644 index 0000000000..7a0646b1a3 --- /dev/null +++ b/changelog.d/7595.misc @@ -0,0 +1 @@ +Refactor `Ratelimiter` to limit the amount of expensive config value accesses. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 7a049b3af7..ec6b3a69a2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,75 +17,157 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock class Ratelimiter(object): """ - Ratelimit message sending by user. + Ratelimit actions marked by arbitrary keys. + + Args: + clock: A homeserver clock, for retrieving the current time + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. """ - def __init__(self): - self.message_counts = ( - OrderedDict() - ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + + # A ordered dictionary keeping track of actions, when they were last + # performed and how often. Each entry is a mapping from a key of arbitrary type + # to a tuple representing: + # * How many times an action has occurred since a point in time + # * The point in time + # * The rate_hz of this particular entry. This can vary per request + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] - def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + def can_do_action( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now. - rate_hz: The long term number of messages a user can send in a - second. - burst_count: How many messages the user can send before being - limited. - update (bool): Whether to update the message rates or not. This is - useful to check if a message would be allowed to be sent before - its ready to be actually sent. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + Returns: - A pair of a bool indicating if they can send a message now and a - time in seconds of when they can next send a message. + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero """ - self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.get( - key, (0.0, time_now_s, None) - ) + # Override default values if set + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count + + # Remove any expired entries + self._prune_message_counts(time_now_s) + + # Check if there is an existing count entry for this key + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + + # Check whether performing another action is allowed time_delta = time_now_s - time_start - sent_count = message_count - time_delta * rate_hz - if sent_count < 0: + performed_count = action_count - time_delta * rate_hz + if performed_count < 0: + # Allow, reset back to count 1 allowed = True time_start = time_now_s - message_count = 1.0 - elif sent_count > burst_count - 1.0: + action_count = 1.0 + elif performed_count > burst_count - 1.0: + # Deny, we have exceeded our burst count allowed = False else: + # We haven't reached our limit yet allowed = True - message_count += 1 + action_count += 1.0 if update: - self.message_counts[key] = (message_count, time_start, rate_hz) + self.actions[key] = (action_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = time_start + (message_count - burst_count + 1) / rate_hz + # Find out when the count of existing actions expires + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz + + # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s + else: + # XXX: Why is this -1? This seems to only be used in + # self.ratelimit. I guess so that clients get a time in the past and don't + # feel afraid to try again immediately time_allowed = -1 return allowed, time_allowed - def prune_message_counts(self, time_now_s): - for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = self.message_counts[key] + def _prune_message_counts(self, time_now_s: int): + """Remove message count entries that have not exceeded their defined + rate_hz limit + + Args: + time_now_s: The current time + """ + # We create a copy of the key list here as the dictionary is modified during + # the loop + for key in list(self.actions.keys()): + action_count, time_start, rate_hz = self.actions[key] + + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if message_count - time_delta * rate_hz > 0: - break + if action_count - time_delta * rate_hz > 0: + continue else: - del self.message_counts[key] + del self.actions[key] + + def ratelimit( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ): + """Checks if an action can be performed. If not, raises a LimitExceededError + + Args: + key: An arbitrary key used to classify an action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Raises: + LimitExceededError: If an action could not be performed, along with the time in + milliseconds until the action can be performed again + """ + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz, burst_count, update + key, + rate_hz=rate_hz, + burst_count=burst_count, + update=update, + _time_now_s=time_now_s, ) if not allowed: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4a3bfc4354..2dd94bae2b 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict + from ._base import Config class RateLimitConfig(object): - def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + def __init__( + self, + config: Dict[str, float], + defaults={"per_second": 0.17, "burst_count": 3.0}, + ): self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = config.get("burst_count", defaults["burst_count"]) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 3b781d9836..61dc4beafe 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError +from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID logger = logging.getLogger(__name__) @@ -44,11 +44,26 @@ class BaseHandler(object): self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() - self.ratelimiter = hs.get_ratelimiter() - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() self.clock = hs.get_clock() self.hs = hs + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = self.hs.config.rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() @@ -70,7 +85,6 @@ class BaseHandler(object): Raises: LimitExceededError if the request should be ratelimited """ - time_now = self.clock.time() user_id = requester.user.to_string() # The AS user itself is never rate limited. @@ -83,48 +97,32 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + # Check if there is a per user override in the DB. override = yield self.store.get_ratelimit_for_user(user_id) if override: - # If overriden with a null Hz then ratelimiting has been entirely + # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user if not override.messages_per_second: return messages_per_second = override.messages_per_second burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) else: - # We default to different values if this is an admin redaction and - # the config is set - if is_admin_redaction and self.hs.config.rc_admin_redaction: - messages_per_second = self.hs.config.rc_admin_redaction.per_second - burst_count = self.hs.config.rc_admin_redaction.burst_count - else: - messages_per_second = self.hs.config.rc_message.per_second - burst_count = self.hs.config.rc_message.burst_count - - if is_admin_redaction and self.hs.config.rc_admin_redaction: - # If we have separate config for admin redactions we use a separate - # ratelimiter - allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - else: - allowed, time_allowed = self.ratelimiter.can_do_action( + # Override rate and burst count per-user + self.request_ratelimiter.ratelimit( user_id, - time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 75b39e878c..119678e67b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -108,7 +108,11 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. - self._failed_uia_attempts_ratelimiter = Ratelimiter() + self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) self._clock = self.hs.get_clock() @@ -196,13 +200,7 @@ class AuthHandler(BaseHandler): user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows flows = [[login_type] for login_type in self._supported_ui_auth_types] @@ -212,14 +210,8 @@ class AuthHandler(BaseHandler): flows, request, request_body, clientip, description ) except LoginError: - # Update the ratelimite to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) raise # find the completed login type diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 681f92cafd..649ca1f08a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -362,7 +362,6 @@ class EventCreationHandler(object): self.profile_handler = hs.get_profile_handler() self.event_builder_factory = hs.get_event_builder_factory() self.server_name = hs.hostname - self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 55a03e53ea..cd746be7c8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -425,14 +425,7 @@ class RegistrationHandler(BaseHandler): if not address: return - time_now = self.clock.time() - - self.ratelimiter.ratelimit( - address, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - ) + self.ratelimiter.ratelimit(address) def register_with_store( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6ac7c5142b..dceb2792fa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -87,11 +87,22 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() - self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter() - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] @@ -124,13 +135,7 @@ class LoginRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - self._address_ratelimiter.ratelimit( - request.getClientIP(), - time_now_s=self.hs.clock.time(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - update=True, - ) + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: @@ -198,13 +203,7 @@ class LoginRestServlet(RestServlet): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) # Check for login providers that support 3pid login types ( @@ -238,13 +237,7 @@ class LoginRestServlet(RestServlet): # If it returned None but the 3PID was bound then we won't hit # this code path, which is fine as then the per-user ratelimit # will kick in below. - self._failed_attempts_ratelimiter.can_do_action( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -263,11 +256,7 @@ class LoginRestServlet(RestServlet): # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, + qualified_user_id.lower(), update=False ) try: @@ -279,13 +268,7 @@ class LoginRestServlet(RestServlet): # limiter. Using `can_do_action` avoids us raising a ratelimit # exception and masking the LoginError. The actual ratelimiting # should have happened above. - self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) raise result = await self._complete_login( @@ -318,13 +301,7 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - update=True, - ) + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index addd4cae19..b9ffe86b2a 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,7 +26,6 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -396,20 +395,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index ca2deb49bb..fe94836a2c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -242,9 +242,12 @@ class HomeServer(object): self.clock = Clock(reactor) self.distributor = Distributor() - self.ratelimiter = Ratelimiter() - self.admin_redaction_ratelimiter = Ratelimiter() - self.registration_ratelimiter = Ratelimiter() + + self.registration_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=config.rc_registration.per_second, + burst_count=config.rc_registration.burst_count, + ) self.datastores = None @@ -314,15 +317,9 @@ class HomeServer(object): def get_distributor(self): return self.distributor - def get_ratelimiter(self): - return self.ratelimiter - - def get_registration_ratelimiter(self): + def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self): - return self.admin_redaction_ratelimiter - def build_federation_client(self): return FederationClient(self) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5ca4521ce3..e5efdfcd02 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -43,7 +43,7 @@ class FederationRateLimiter(object): self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): - """Used to ratelimit an incoming request from given host + """Used to ratelimit an incoming request from a given host Example usage: diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dbdd427cac..d580e729c5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,39 +1,97 @@ -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from tests import unittest class TestRatelimiter(unittest.TestCase): - def test_allowed(self): - limiter = Ratelimiter() - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 - ) + def test_allowed_via_can_do_action(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) - def test_pruning(self): - limiter = Ratelimiter() + def test_allowed_via_ratelimit(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=0) + + # Should raise + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key="test_id", _time_now_s=5) + self.assertEqual(context.exception.retry_after_ms, 5000) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=10) + + def test_allowed_via_can_do_action_and_overriding_parameters(self): + """Test that we can override options of can_do_action that would otherwise fail + an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,) + self.assertTrue(allowed) + self.assertEqual(10.0, time_allowed) + + # Second attempt, 1s later, will fail + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,) + self.assertFalse(allowed) + self.assertEqual(10.0, time_allowed) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. allowed, time_allowed = limiter.can_do_action( - key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, rate_hz=10.0 ) + self.assertTrue(allowed) + self.assertEqual(1.1, time_allowed) - self.assertIn("test_id_1", limiter.message_counts) - + # Similarly if we allow a burst of 10 actions allowed, time_allowed = limiter.can_do_action( - key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, burst_count=10 ) + self.assertTrue(allowed) + self.assertEqual(1.0, time_allowed) + + def test_allowed_via_ratelimit_and_overriding_parameters(self): + """Test that we can override options of the ratelimit method that would otherwise + fail an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + limiter.ratelimit(key=("test_id",), _time_now_s=0) + + # Second attempt, 1s later, will fail + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.assertEqual(context.exception.retry_after_ms, 9000) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. + limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + + # Similarly if we allow a burst of 10 actions + limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + + def test_pruning(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter.can_do_action(key="test_id_1", _time_now_s=0) + + self.assertIn("test_id_1", limiter.actions) + + limiter.can_do_action(key="test_id_2", _time_now_s=10) - self.assertNotIn("test_id_1", limiter.message_counts) + self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 8aa56f1496..29dd7d9c6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 32cb04645f..56497b8476 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from tests.replication._base import BaseStreamTestCase @@ -21,12 +21,7 @@ from tests.replication._base import BaseStreamTestCase class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), - ) - - hs.get_ratelimiter().can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(federation_client=Mock()) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index b54b06482b..f75520877f 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room @@ -40,11 +40,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(config=config) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 0f0f7ca72d..9033f09fd2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,7 +29,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -38,10 +37,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.config.rc_login_address.burst_count = 5 - self.hs.config.rc_login_address.per_second = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -80,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.config.rc_login_account.burst_count = 5 - self.hs.config.rc_login_account.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -119,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.config.rc_login_failed_attempts.burst_count = 5 - self.hs.config.rc_login_failed_attempts.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7dd86d0c27..4886bbb401 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,7 +20,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from six.moves.urllib import parse as urlparse from twisted.internet import defer @@ -46,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) - self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 4bc3aaf02d..18260bb90e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() def get_user_by_access_token(token=None, allow_guest=False): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 5637ce2090..7deaf5b24a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest +from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -146,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") @@ -168,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): params = { "username": "kermit" + str(i), -- cgit 1.5.1