From 509e381afa8c656e72f5fef3d651a9819794174a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Feb 2020 07:15:07 -0500 Subject: Clarify list/set/dict/tuple comprehensions and enforce via flake8 (#6957) Ensure good comprehension hygiene using flake8-comprehensions. --- tests/handlers/test_presence.py | 4 ++-- tests/handlers/test_typing.py | 6 +++--- tests/handlers/test_user_directory.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c171038df8..64915bafcd 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -338,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set([user_id]), now=now + state, is_mine=True, syncing_user_ids={user_id}, now=now ) self.assertIsNotNone(new_state) @@ -579,7 +579,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), states=[expected_state] + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 140cc0a3c2..07b204666e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,12 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): - return set(member.domain for member in self.room_members) + return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return set(str(u) for u in self.room_members) + return {str(u) for u in self.room_members} hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -257,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) + self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} self.assertEquals(self.event_source.get_current_key(), 0) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0a4765fff4..7b92bdbc47 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -226,7 +226,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -358,12 +358,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms self.assertEqual( self._compress_shared(shares_private), - set([(u1, u3, private_room), (u3, u1, private_room)]), + {(u1, u3, private_room), (u3, u1, private_room)}, ) def test_initial_share_all_users(self): @@ -398,7 +398,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set([])) + self.assertEqual(self._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. -- cgit 1.4.1 From 1f773eec912e4908ab60f7823f5c0a024261af4d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 26 Feb 2020 15:33:26 +0000 Subject: Port PresenceHandler to async/await (#6991) --- changelog.d/6991.misc | 1 + synapse/handlers/message.py | 5 +- synapse/handlers/presence.py | 192 ++++++++++++++++-------------------- synapse/replication/tcp/resource.py | 6 +- synapse/server.pyi | 5 + tests/handlers/test_presence.py | 18 ++-- tox.ini | 1 + 7 files changed, 113 insertions(+), 115 deletions(-) create mode 100644 changelog.d/6991.misc (limited to 'tests/handlers') diff --git a/changelog.d/6991.misc b/changelog.d/6991.misc new file mode 100644 index 0000000000..5130f4e8af --- /dev/null +++ b/changelog.d/6991.misc @@ -0,0 +1 @@ +Port `synapse.handlers.presence` to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6be280952..a0103addd3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1016,11 +1016,10 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - @defer.inlineCallbacks - def _bump_active_time(self, user): + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() - yield presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0d6cf2b008..5526015ddb 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -24,11 +24,12 @@ The methods that define policy are: import logging from contextlib import contextmanager -from typing import Dict, Set +from typing import Dict, List, Set from six import iteritems, itervalues from prometheus_client import Counter +from typing_extensions import ContextManager from twisted.internet import defer @@ -42,10 +43,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +MYPY = False +if MYPY: + import synapse.server + logger = logging.getLogger(__name__) @@ -97,7 +102,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs - self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname self.clock = hs.get_clock() @@ -150,7 +154,7 @@ class PresenceHandler(object): # Set of users who have presence in the `user_to_current_state` that # have not yet been persisted - self.unpersisted_users_changes = set() + self.unpersisted_users_changes = set() # type: Set[str] hs.get_reactor().addSystemEventTrigger( "before", @@ -160,12 +164,11 @@ class PresenceHandler(object): self._on_shutdown, ) - self.serial_to_user = {} self._next_serial = 1 # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs = {} + self.user_to_num_current_syncs = {} # type: Dict[str, int] # Keeps track of the number of *ongoing* syncs on other processes. # While any sync is ongoing on another process the user will never @@ -213,8 +216,7 @@ class PresenceHandler(object): self._event_pos = self.store.get_current_events_token() self._event_processing = False - @defer.inlineCallbacks - def _on_shutdown(self): + async def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal timers. This allows changes to persist across startup without having to @@ -235,7 +237,7 @@ class PresenceHandler(object): if self.unpersisted_users_changes: - yield self.store.update_presence( + await self.store.update_presence( [ self.user_to_current_state[user_id] for user_id in self.unpersisted_users_changes @@ -243,8 +245,7 @@ class PresenceHandler(object): ) logger.info("Finished _on_shutdown") - @defer.inlineCallbacks - def _persist_unpersisted_changes(self): + async def _persist_unpersisted_changes(self): """We periodically persist the unpersisted changes, as otherwise they may stack up and slow down shutdown times. """ @@ -253,12 +254,11 @@ class PresenceHandler(object): if unpersisted: logger.info("Persisting %d unpersisted presence updates", len(unpersisted)) - yield self.store.update_presence( + await self.store.update_presence( [self.user_to_current_state[user_id] for user_id in unpersisted] ) - @defer.inlineCallbacks - def _update_states(self, new_states): + async def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. @@ -267,7 +267,7 @@ class PresenceHandler(object): with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't yield between now and when we've + # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. to_notify = {} # Changes we want to notify everyone about @@ -311,7 +311,7 @@ class PresenceHandler(object): if to_notify: notified_presence_counter.inc(len(to_notify)) - yield self._persist_and_notify(list(to_notify.values())) + await self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) @@ -326,7 +326,7 @@ class PresenceHandler(object): self._push_to_remotes(to_federation_ping.values()) - def _handle_timeouts(self): + async def _handle_timeouts(self): """Checks the presence of users that have timed out and updates as appropriate. """ @@ -368,10 +368,9 @@ class PresenceHandler(object): now=now, ) - return self._update_states(changes) + return await self._update_states(changes) - @defer.inlineCallbacks - def bump_presence_active_time(self, user): + async def bump_presence_active_time(self, user): """We've seen the user do something that indicates they're interacting with the app. """ @@ -383,16 +382,17 @@ class PresenceHandler(object): bump_active_time_counter.inc() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"last_active_ts": self.clock.time_msec()} if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def user_syncing(self, user_id, affect_presence=True): + async def user_syncing( + self, user_id: str, affect_presence: bool = True + ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -415,11 +415,11 @@ class PresenceHandler(object): curr_sync = self.user_to_num_current_syncs.get(user_id, 0) self.user_to_num_current_syncs[user_id] = curr_sync + 1 - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) if prev_state.state == PresenceState.OFFLINE: # If they're currently offline then bring them online, otherwise # just update the last sync times. - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( state=PresenceState.ONLINE, @@ -429,7 +429,7 @@ class PresenceHandler(object): ] ) else: - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -437,13 +437,12 @@ class PresenceHandler(object): ] ) - @defer.inlineCallbacks - def _end(): + async def _end(): try: self.user_to_num_current_syncs[user_id] -= 1 - prev_state = yield self.current_state_for_user(user_id) - yield self._update_states( + prev_state = await self.current_state_for_user(user_id) + await self._update_states( [ prev_state.copy_and_replace( last_user_sync_ts=self.clock.time_msec() @@ -480,8 +479,7 @@ class PresenceHandler(object): else: return set() - @defer.inlineCallbacks - def update_external_syncs_row( + async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec ): """Update the syncing users for an external process as a delta. @@ -494,8 +492,8 @@ class PresenceHandler(object): is_syncing (bool): Whether or not the user is now syncing sync_time_msec(int): Time in ms when the user was last syncing """ - with (yield self.external_sync_linearizer.queue(process_id)): - prev_state = yield self.current_state_for_user(user_id) + with (await self.external_sync_linearizer.queue(process_id)): + prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( process_id, set() @@ -525,25 +523,24 @@ class PresenceHandler(object): process_presence.discard(user_id) if updates: - yield self._update_states(updates) + await self._update_states(updates) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() - @defer.inlineCallbacks - def update_external_syncs_clear(self, process_id): + async def update_external_syncs_clear(self, process_id): """Marks all users that had been marked as syncing by a given process as offline. Used when the process has stopped/disappeared. """ - with (yield self.external_sync_linearizer.queue(process_id)): + with (await self.external_sync_linearizer.queue(process_id)): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = yield self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users(process_presence) time_now_ms = self.clock.time_msec() - yield self._update_states( + await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) for prev_state in itervalues(prev_states) @@ -551,15 +548,13 @@ class PresenceHandler(object): ) self.external_process_last_updated_ms.pop(process_id, None) - @defer.inlineCallbacks - def current_state_for_user(self, user_id): + async def current_state_for_user(self, user_id): """Get the current presence state for a user. """ - res = yield self.current_state_for_users([user_id]) + res = await self.current_state_for_users([user_id]) return res[user_id] - @defer.inlineCallbacks - def current_state_for_users(self, user_ids): + async def current_state_for_users(self, user_ids): """Get the current presence state for multiple users. Returns: @@ -574,7 +569,7 @@ class PresenceHandler(object): if missing: # There are things not in our in memory cache. Lets pull them out of # the database. - res = yield self.store.get_presence_for_users(missing) + res = await self.store.get_presence_for_users(missing) states.update(res) missing = [user_id for user_id, state in iteritems(states) if not state] @@ -587,14 +582,13 @@ class PresenceHandler(object): return states - @defer.inlineCallbacks - def _persist_and_notify(self, states): + async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers """ - stream_id, max_token = yield self.store.update_presence(states) + stream_id, max_token = await self.store.update_presence(states) - parties = yield get_interested_parties(self.store, states) + parties = await get_interested_parties(self.store, states) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -606,9 +600,8 @@ class PresenceHandler(object): self._push_to_remotes(states) - @defer.inlineCallbacks - def notify_for_states(self, state, stream_id): - parties = yield get_interested_parties(self.store, [state]) + async def notify_for_states(self, state, stream_id): + parties = await get_interested_parties(self.store, [state]) room_ids_to_states, users_to_states = parties self.notifier.on_new_event( @@ -626,8 +619,7 @@ class PresenceHandler(object): """ self.federation.send_presence(states) - @defer.inlineCallbacks - def incoming_presence(self, origin, content): + async def incoming_presence(self, origin, content): """Called when we receive a `m.presence` EDU from a remote server. """ now = self.clock.time_msec() @@ -670,21 +662,19 @@ class PresenceHandler(object): new_fields["status_msg"] = push.get("status_msg", None) new_fields["currently_active"] = push.get("currently_active", False) - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) updates.append(prev_state.copy_and_replace(**new_fields)) if updates: federation_presence_counter.inc(len(updates)) - yield self._update_states(updates) + await self._update_states(updates) - @defer.inlineCallbacks - def get_state(self, target_user, as_event=False): - results = yield self.get_states([target_user.to_string()], as_event=as_event) + async def get_state(self, target_user, as_event=False): + results = await self.get_states([target_user.to_string()], as_event=as_event) return results[0] - @defer.inlineCallbacks - def get_states(self, target_user_ids, as_event=False): + async def get_states(self, target_user_ids, as_event=False): """Get the presence state for users. Args: @@ -695,7 +685,7 @@ class PresenceHandler(object): list """ - updates = yield self.current_state_for_users(target_user_ids) + updates = await self.current_state_for_users(target_user_ids) updates = list(updates.values()) for user_id in set(target_user_ids) - {u.user_id for u in updates}: @@ -713,8 +703,7 @@ class PresenceHandler(object): else: return updates - @defer.inlineCallbacks - def set_state(self, target_user, state, ignore_status_msg=False): + async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -730,7 +719,7 @@ class PresenceHandler(object): user_id = target_user.to_string() - prev_state = yield self.current_state_for_user(user_id) + prev_state = await self.current_state_for_user(user_id) new_fields = {"state": presence} @@ -741,16 +730,15 @@ class PresenceHandler(object): if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() - yield self._update_states([prev_state.copy_and_replace(**new_fields)]) + await self._update_states([prev_state.copy_and_replace(**new_fields)]) - @defer.inlineCallbacks - def is_visible(self, observed_user, observer_user): + async def is_visible(self, observed_user, observer_user): """Returns whether a user can see another user's presence. """ - observer_room_ids = yield self.store.get_rooms_for_user( + observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) - observed_room_ids = yield self.store.get_rooms_for_user( + observed_room_ids = await self.store.get_rooms_for_user( observed_user.to_string() ) @@ -759,8 +747,7 @@ class PresenceHandler(object): return False - @defer.inlineCallbacks - def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -775,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = yield self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id) return rows def notify_new_event(self): @@ -786,20 +773,18 @@ class PresenceHandler(object): if self._event_processing: return - @defer.inlineCallbacks - def _process_presence(): + async def _process_presence(): assert not self._event_processing self._event_processing = True try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._event_processing = False run_as_background_process("presence.notify_new_event", _process_presence) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "presence_delta"): @@ -812,10 +797,10 @@ class PresenceHandler(object): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) - yield self._handle_state_delta(deltas) + await self._handle_state_delta(deltas) self._event_pos = max_pos @@ -824,8 +809,7 @@ class PresenceHandler(object): max_pos ) - @defer.inlineCallbacks - def _handle_state_delta(self, deltas): + async def _handle_state_delta(self, deltas): """Process current state deltas to find new joins that need to be handled. """ @@ -846,13 +830,13 @@ class PresenceHandler(object): # joins. continue - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if ( prev_event and prev_event.content.get("membership") == Membership.JOIN @@ -860,10 +844,9 @@ class PresenceHandler(object): # Ignore changes to join events. continue - yield self._on_user_joined_room(room_id, state_key) + await self._on_user_joined_room(room_id, state_key) - @defer.inlineCallbacks - def _on_user_joined_room(self, room_id, user_id): + async def _on_user_joined_room(self, room_id, user_id): """Called when we detect a user joining the room via the current state delta stream. @@ -882,8 +865,8 @@ class PresenceHandler(object): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = yield self.current_state_for_user(user_id) - hosts = yield self.state.get_current_hosts_in_room(room_id) + state = await self.current_state_for_user(user_id) + hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. hosts = {host for host in hosts if host != self.server_name} @@ -903,10 +886,10 @@ class PresenceHandler(object): # TODO: Check that this is actually a new server joining the # room. - user_ids = yield self.state.get_current_users_in_room(room_id) + user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = yield self.current_state_for_users(user_ids) + states = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -996,9 +979,8 @@ class PresenceEventSource(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - @defer.inlineCallbacks @log_function - def get_new_events( + async def get_new_events( self, user, from_key, @@ -1045,7 +1027,7 @@ class PresenceEventSource(object): presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache - users_interested_in = yield self._get_interested_in(user, explicit_room_id) + users_interested_in = await self._get_interested_in(user, explicit_room_id) user_ids_changed = set() changed = None @@ -1071,7 +1053,7 @@ class PresenceEventSource(object): else: user_ids_changed = users_interested_in - updates = yield presence.current_state_for_users(user_ids_changed) + updates = await presence.current_state_for_users(user_ids_changed) if include_offline: return (list(updates.values()), max_token) @@ -1084,11 +1066,11 @@ class PresenceEventSource(object): def get_current_key(self): return self.store.get_current_presence_token() - def get_pagination_rows(self, user, pagination_config, key): - return self.get_new_events(user, from_key=None, include_offline=False) + async def get_pagination_rows(self, user, pagination_config, key): + return await self.get_new_events(user, from_key=None, include_offline=False) - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _get_interested_in(self, user, explicit_room_id, cache_context): + @cached(num_args=2, cache_context=True) + async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence updates for """ @@ -1096,13 +1078,13 @@ class PresenceEventSource(object): users_interested_in = set() users_interested_in.add(user_id) # So that we receive our own presence - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(users_who_share_room) if explicit_room_id: - user_ids = yield self.store.get_users_in_room( + user_ids = await self.store.get_users_in_room( explicit_room_id, on_invalidate=cache_context.invalidate ) users_interested_in.update(user_ids) @@ -1277,8 +1259,8 @@ def get_interested_parties(store, states): 2-tuple: `(room_ids_to_states, users_to_states)`, with each item being a dict of `entity_name` -> `[UserPresenceState]` """ - room_ids_to_states = {} - users_to_states = {} + room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]] + users_to_states = {} # type: Dict[str, List[UserPresenceState]] for state in states: room_ids = yield store.get_rooms_for_user(state.user_id) for room_id in room_ids: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce60ae2e07..ce9d1fae12 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -323,7 +323,11 @@ class ReplicationStreamer(object): # We need to tell the presence handler that the connection has been # lost so that it can handle any ongoing syncs on that connection. - self.presence_handler.update_external_syncs_clear(connection.conn_id) + run_as_background_process( + "update_external_syncs_clear", + self.presence_handler.update_external_syncs_clear, + connection.conn_id, + ) def _batch_updates(updates): diff --git a/synapse/server.pyi b/synapse/server.pyi index 40eabfe5d9..3844f0e12f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -3,6 +3,7 @@ import twisted.internet import synapse.api.auth import synapse.config.homeserver import synapse.crypto.keyring +import synapse.federation.federation_server import synapse.federation.sender import synapse.federation.transport.client import synapse.handlers @@ -107,5 +108,9 @@ class HomeServer(object): self, ) -> synapse.replication.tcp.client.ReplicationClientHandler: pass + def get_federation_registry( + self, + ) -> synapse.federation.federation_server.FederationHandlerRegistry: + pass def is_mine_id(self, domain_id: str) -> bool: pass diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 64915bafcd..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -494,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -543,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room diff --git a/tox.ini b/tox.ini index b715ea0bff..4ccfde01b5 100644 --- a/tox.ini +++ b/tox.ini @@ -183,6 +183,7 @@ commands = mypy \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ synapse/logging/ \ -- cgit 1.4.1 From 3e99528f2bfaa686c4708fb8efcddce935b2397d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 26 Feb 2020 16:58:33 +0000 Subject: Store room version on invite (#6983) When we get an invite over federation, store the room version in the rooms table. The general idea here is that, when we pull the invite out again, we'll want to know what room_version it belongs to (so that we can later redact it if need be). So we need to store it somewhere... --- changelog.d/6983.misc | 1 + synapse/handlers/federation.py | 12 +++++++++++ synapse/replication/http/_base.py | 2 +- synapse/replication/http/federation.py | 36 +++++++++++++++++++++++++++++++- synapse/storage/data_stores/main/room.py | 20 ++++++++++++++++++ tests/app/test_openid_listener.py | 8 +++++++ tests/handlers/test_typing.py | 1 + 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 changelog.d/6983.misc (limited to 'tests/handlers') diff --git a/changelog.d/6983.misc b/changelog.d/6983.misc new file mode 100644 index 0000000000..08aa80bcd9 --- /dev/null +++ b/changelog.d/6983.misc @@ -0,0 +1 @@ +Refactoring work in preparation for changing the event redaction algorithm. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c2e6ee266d..38ab6a8fc3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -60,6 +60,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, + ReplicationStoreRoomOnInviteRestServlet, ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store @@ -160,8 +161,12 @@ class FederationHandler(BaseHandler): self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( hs ) + self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + hs + ) else: self._device_list_updater = hs.get_device_handler().device_list_updater + self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite # When joining a room we need to queue any events for that room up self.room_queues = {} @@ -1537,6 +1542,13 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + # keep a record of the room version, if we don't yet know it. + # (this may get overwritten if we later get a different room version in a + # join dance). + await self._maybe_store_room_on_invite( + room_id=event.room_id, room_version=room_version + ) + event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 444eb7b7f4..1be1ccbdf3 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -44,7 +44,7 @@ class ReplicationEndpoint(object): """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` - (with an `/:txn_id` prefix for cached requests.), where NAME is a name, + (with a `/:txn_id` suffix for cached requests), where NAME is a name, PATH_ARGS are a tuple of parameters to be encoded in the URL. For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 49a3251372..8794720101 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import event_type_from_format_version from synapse.events.snapshot import EventContext from synapse.http.servlet import parse_json_object_from_request @@ -211,7 +212,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): Request format: - POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id + POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id {} """ @@ -238,8 +239,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): return 200, {} +class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): + """Called to clean up any data in DB for a given room, ready for the + server to join the room. + + Request format: + + POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id + + { + "room_version": "1", + } + """ + + NAME = "store_room_on_invite" + PATH_ARGS = ("room_id",) + + def __init__(self, hs): + super().__init__(hs) + + self.store = hs.get_datastore() + + @staticmethod + def _serialize_payload(room_id, room_version): + return {"room_version": room_version.identifier} + + async def _handle_request(self, request, room_id): + content = parse_json_object_from_request(request) + room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] + await self.store.maybe_store_room_on_invite(room_id, room_version) + return 200, {} + + def register_servlets(hs, http_server): ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server) + ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 70137dfbe4..e6c10c6316 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -1020,6 +1020,26 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") + async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + """ + When we receive an invite over federation, store the version of the room if we + don't already know the room version. + """ + await self.db.simple_upsert( + desc="maybe_store_room_on_invite", + table="rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + "room_version": room_version.identifier, + "is_public": False, + "creator": "", + }, + # rooms has a unique constraint on room_id, so no need to lock when doing an + # emulated upsert. + lock=False, + ) + @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 1fe048048b..89fcc3889a 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -29,6 +29,14 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): ) return hs + def default_config(self, name="test"): + conf = super().default_config(name) + # we're using FederationReaderServer, which uses a SlavedStore, so we + # have to tell the FederationHandler not to try to access stuff that is only + # in the primary store. + conf["worker_app"] = "yes" + return conf + @parameterized.expand( [ (["federation"], "auth_fail"), diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 07b204666e..51e2b37218 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -74,6 +74,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "set_received_txn_response", "get_destination_retry_timings", "get_devices_by_remote", + "maybe_store_room_on_invite", # Bits that user_directory needs "get_user_directory_stream_pos", "get_current_state_deltas", -- cgit 1.4.1 From 7dcbc33a1be04c46b930699c03c15bc759f4b22c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 3 Mar 2020 07:12:45 -0500 Subject: Validate the alt_aliases property of canonical alias events (#6971) --- changelog.d/6971.feature | 1 + synapse/api/errors.py | 1 + synapse/handlers/directory.py | 14 ++-- synapse/handlers/message.py | 47 ++++++++++- synapse/types.py | 15 ++-- tests/handlers/test_directory.py | 66 +++++++-------- tests/rest/client/v1/test_rooms.py | 160 +++++++++++++++++++++++++++++++++++++ tests/test_types.py | 2 +- 8 files changed, 254 insertions(+), 52 deletions(-) create mode 100644 changelog.d/6971.feature (limited to 'tests/handlers') diff --git a/changelog.d/6971.feature b/changelog.d/6971.feature new file mode 100644 index 0000000000..ccf02a61df --- /dev/null +++ b/changelog.d/6971.feature @@ -0,0 +1 @@ +Validate the alt_aliases property of canonical alias events. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 0c20601600..616942b057 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -66,6 +66,7 @@ class Codes(object): EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + BAD_ALIAS = "M_BAD_ALIAS" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 0b23ca919a..61eb49059b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import collections import logging import string from typing import List @@ -307,15 +305,17 @@ class DirectoryHandler(BaseHandler): send_update = True content.pop("alias", "") - # Filter alt_aliases for the removed alias. - alt_aliases = content.pop("alt_aliases", None) - # If the aliases are not a list (or not found) do not attempt to modify - # the list. - if isinstance(alt_aliases, collections.Sequence): + # Filter the alt_aliases property for the removed alias. Note that the + # value is not modified if alt_aliases is of an unexpected form. + alt_aliases = content.get("alt_aliases") + if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases: send_update = True alt_aliases = [alias for alias in alt_aliases if alias != alias_str] + if alt_aliases: content["alt_aliases"] = alt_aliases + else: + del content["alt_aliases"] if send_update: yield self.event_creation_handler.create_and_send_nonmember_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a0103addd3..0c84c6cec4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -888,19 +888,60 @@ class EventCreationHandler(object): yield self.base_handler.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) + # Validate a newly added alias or newly added alt_aliases. + + original_alias = None + original_alt_aliases = set() + + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_event = yield self.store.get_event(original_event_id) + + if original_event: + original_alias = original_event.content.get("alias", None) + original_alt_aliases = original_event.content.get("alt_aliases", []) + + # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - if room_alias_str: + directory_handler = self.hs.get_handlers().directory_handler + if room_alias_str and room_alias_str != original_alias: room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler mapping = yield directory_handler.get_association(room_alias) if mapping["room_id"] != event.room_id: raise SynapseError( 400, "Room alias %s does not point to the room" % (room_alias_str,), + Codes.BAD_ALIAS, ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM + ) + + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + room_alias = RoomAlias.from_string(alias_str) + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" + % (room_alias_str,), + Codes.BAD_ALIAS, + ) + federation_handler = self.hs.get_handlers().federation_handler if event.type == EventTypes.Member: diff --git a/synapse/types.py b/synapse/types.py index f3cd465735..acf60baddc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -23,7 +23,7 @@ import attr from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError # define a version of typing.Collection that works on python 3.5 if sys.version_info[:3] >= (3, 6, 0): @@ -166,11 +166,13 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str): """Parse the string given by 's' into a structure object.""" if len(s) < 1 or s[0:1] != cls.SIGIL: raise SynapseError( - 400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL) + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) parts = s[1:].split(":", 1) @@ -179,6 +181,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom 400, "Expected %s of the form '%slocalname:domain'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, ) domain = parts[1] @@ -235,11 +238,13 @@ class GroupID(DomainSpecificString): def from_string(cls, s): group_id = super(GroupID, cls).from_string(s) if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty") + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) if contains_invalid_mxid_characters(group_id.localpart): raise SynapseError( - 400, "Group ID can only contain characters a-z, 0-9, or '=_-./'" + 400, + "Group ID can only contain characters a-z, 0-9, or '=_-./'", + Codes.INVALID_PARAM, ) return group_id diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 27b916aed4..3397cfa485 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -88,6 +88,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias_not_allowed(self): + """Removing an alias should be denied if a user does not have the proper permissions.""" room_id = "!8765qwer:test" self.get_success( self.store.create_room_alias_association(self.my_room, room_id, ["test"]) @@ -101,6 +102,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) def test_delete_alias(self): + """Removing an alias should work when a user does has the proper permissions.""" room_id = "!8765qwer:test" user_id = "@user:test" self.get_success( @@ -159,30 +161,42 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) self.test_alias = "#test:test" - self.room_alias = RoomAlias.from_string(self.test_alias) + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( - self.room_alias, self.room_id, ["test"], self.admin_user + room_alias, self.room_id, ["test"], self.admin_user ) ) + return room_alias - def test_remove_alias(self): - """Removing an alias that is the canonical alias should remove it there too.""" - # Set this new alias as the canonical alias for this room + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, - "m.room.canonical_alias", - {"alias": self.test_alias, "alt_aliases": [self.test_alias]}, - tok=self.admin_user_tok, + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, ) - data = self.get_success( + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( self.state_handler.get_current_state( self.room_id, EventTypes.CanonicalAlias, "" ) ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) @@ -193,11 +207,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) @@ -205,29 +215,17 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" - other_room_alias = RoomAlias.from_string(other_test_alias) - self.get_success( - self.store.create_room_alias_association( - other_room_alias, self.room_id, ["test"], self.admin_user - ) - ) + other_room_alias = self._add_alias(other_test_alias) # Set the alias as the canonical alias for this room. - self.helper.send_state( - self.room_id, - "m.room.canonical_alias", + self._set_canonical_alias( { "alias": self.test_alias, "alt_aliases": [self.test_alias, other_test_alias], - }, - tok=self.admin_user_tok, + } ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual( data["content"]["alt_aliases"], [self.test_alias, other_test_alias] @@ -240,11 +238,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - data = self.get_success( - self.state_handler.get_current_state( - self.room_id, EventTypes.CanonicalAlias, "" - ) - ) + data = self._get_canonical_alias() self.assertEqual(data["content"]["alias"], self.test_alias) self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2f3df5f88f..7dd86d0c27 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) + + +class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.alias = "#alias:test" + self._set_alias_via_directory(self.alias) + + def _set_alias_via_directory(self, alias: str, expected_code: int = 200): + url = "/_matrix/client/r0/directory/room/" + alias + data = {"room_id": self.room_id} + request_data = json.dumps(data) + + request, channel = self.make_request( + "PUT", url, request_data, access_token=self.room_owner_tok + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + + def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "GET", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. returns the json response object.""" + request, channel = self.make_request( + "PUT", + "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), + json.dumps(content), + access_token=self.room_owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, expected_code, channel.result) + res = channel.json_body + self.assertIsInstance(res, dict) + return res + + def test_canonical_alias(self): + """Test a basic alias message.""" + # There is no canonical alias to start with. + self._get_canonical_alias(expected_code=404) + + # Create an alias. + self._set_canonical_alias({"alias": self.alias}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + # Now remove the alias. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alt_aliases(self): + """Test a canonical alias message with alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_alias_alt_aliases(self): + """Test a canonical alias message with an alias and alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alias and alt_aliases. + self._set_canonical_alias({}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {}) + + def test_partial_modify(self): + """Test removing only the alt_aliases.""" + # Create an alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]}) + + # Now remove the alt_aliases. + self._set_canonical_alias({"alias": self.alias}) + + # There is an alias event, but it is empty. + res = self._get_canonical_alias() + self.assertEqual(res, {"alias": self.alias}) + + def test_add_alias(self): + """Test removing only the alt_aliases.""" + # Create an additional alias. + second_alias = "#second:test" + self._set_alias_via_directory(second_alias) + + # Add the canonical alias. + self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]}) + + # Then add the second alias. + self._set_canonical_alias( + {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + # Canonical alias now exists! + res = self._get_canonical_alias() + self.assertEqual( + res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]} + ) + + def test_bad_data(self): + """Invalid data for alt_aliases should cause errors.""" + self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": None}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 0}, expected_code=400) + self._set_canonical_alias({"alt_aliases": 1}, expected_code=400) + self._set_canonical_alias({"alt_aliases": False}, expected_code=400) + self._set_canonical_alias({"alt_aliases": True}, expected_code=400) + self._set_canonical_alias({"alt_aliases": {}}, expected_code=400) + + def test_bad_alias(self): + """An alias which does not point to the room raises a SynapseError.""" + self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) + self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) diff --git a/tests/test_types.py b/tests/test_types.py index 8d97c751ea..480bea1bdc 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -75,7 +75,7 @@ class GroupIDTestCase(unittest.TestCase): self.fail("Parsing '%s' should raise exception" % id_string) except SynapseError as exc: self.assertEqual(400, exc.code) - self.assertEqual("M_UNKNOWN", exc.errcode) + self.assertEqual("M_INVALID_PARAM", exc.errcode) class MapUsernameTestCase(unittest.TestCase): -- cgit 1.4.1 From 13892776ef7e0b1af2f82c9ca53f7bbd1c60d66f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 4 Mar 2020 11:30:46 -0500 Subject: Allow deleting an alias if the user has sufficient power level (#6986) --- changelog.d/6986.feature | 1 + synapse/api/auth.py | 9 +-- synapse/handlers/directory.py | 107 ++++++++++++++++++++++---------- tests/handlers/test_directory.py | 128 +++++++++++++++++++++++++++++++-------- tox.ini | 1 + 5 files changed, 182 insertions(+), 64 deletions(-) create mode 100644 changelog.d/6986.feature (limited to 'tests/handlers') diff --git a/changelog.d/6986.feature b/changelog.d/6986.feature new file mode 100644 index 0000000000..16dea8bd7f --- /dev/null +++ b/changelog.d/6986.feature @@ -0,0 +1 @@ +Users with a power level sufficient to modify the canonical alias of a room can now delete room aliases. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5ca18b4301..c1ade1333b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -539,7 +539,7 @@ class Auth(object): @defer.inlineCallbacks def check_can_change_room_list(self, room_id: str, user: UserID): - """Check if the user is allowed to edit the room's entry in the + """Determine whether the user is allowed to edit the room's entry in the published room list. Args: @@ -570,12 +570,7 @@ class Auth(object): ) user_level = event_auth.get_user_power_level(user_id, auth_events) - if user_level < send_level: - raise AuthError( - 403, - "This server requires you to be a moderator in the room to" - " edit its room list entry", - ) + return user_level >= send_level @staticmethod def has_access_token(request): diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 61eb49059b..1d842c369b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -15,7 +15,7 @@ import logging import string -from typing import List +from typing import Iterable, List, Optional from twisted.internet import defer @@ -28,6 +28,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.appservice import ApplicationService from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id from ._base import BaseHandler @@ -55,7 +56,13 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None, creator=None): + def _create_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Optional[Iterable[str]] = None, + creator: Optional[str] = None, + ): # general association creation for both human users and app services for wchar in string.whitespace: @@ -81,17 +88,21 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def create_association( - self, requester, room_alias, room_id, servers=None, check_membership=True, + self, + requester: Requester, + room_alias: RoomAlias, + room_id: str, + servers: Optional[List[str]] = None, + check_membership: bool = True, ): """Attempt to create a new alias Args: - requester (Requester) - room_alias (RoomAlias) - room_id (str) - servers (list[str]|None): List of servers that others servers - should try and join via - check_membership (bool): Whether to check if the user is in the room + requester + room_alias + room_id + servers: Iterable of servers that others servers should try and join via + check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). Returns: @@ -145,15 +156,15 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers, creator=user_id) @defer.inlineCallbacks - def delete_association(self, requester, room_alias): + def delete_association(self, requester: Requester, room_alias: RoomAlias): """Remove an alias from the directory (this is only meant for human users; AS users should call delete_appservice_association) Args: - requester (Requester): - room_alias (RoomAlias): + requester + room_alias Returns: Deferred[unicode]: room id that the alias used to point to @@ -189,16 +200,16 @@ class DirectoryHandler(BaseHandler): room_id = yield self._delete_association(room_alias) try: - yield self._update_canonical_alias( - requester, requester.user.to_string(), room_id, room_alias - ) + yield self._update_canonical_alias(requester, user_id, room_id, room_alias) except AuthError as e: logger.info("Failed to update alias events: %s", e) return room_id @defer.inlineCallbacks - def delete_appservice_association(self, service, room_alias): + def delete_appservice_association( + self, service: ApplicationService, room_alias: RoomAlias + ): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, @@ -208,7 +219,7 @@ class DirectoryHandler(BaseHandler): yield self._delete_association(room_alias) @defer.inlineCallbacks - def _delete_association(self, room_alias): + def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") @@ -217,7 +228,7 @@ class DirectoryHandler(BaseHandler): return room_id @defer.inlineCallbacks - def get_association(self, room_alias): + def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): result = yield self.get_association_from_room_alias(room_alias) @@ -282,7 +293,9 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + def _update_canonical_alias( + self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias + ): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. @@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): + def get_association_from_room_alias(self, room_alias: RoomAlias): result = yield self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists @@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler): result = yield as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias, user_id=None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have @@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler): return defer.succeed(True) @defer.inlineCallbacks - def _user_can_delete_alias(self, alias, user_id): + def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): + """Determine whether a user can delete an alias. + + One of the following must be true: + + 1. The user created the alias. + 2. The user is a server administrator. + 3. The user has a power-level sufficient to send a canonical alias event + for the current room. + + """ creator = yield self.store.get_room_alias_creator(alias.to_string()) if creator is not None and creator == user_id: return True - is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) - return is_admin + # Resolve the alias to the corresponding room. + room_mapping = yield self.get_association(alias) + room_id = room_mapping["room_id"] + if not room_id: + return False + + res = yield self.auth.check_can_change_room_list( + room_id, UserID.from_string(user_id) + ) + return res @defer.inlineCallbacks - def edit_published_room_list(self, requester, room_id, visibility): + def edit_published_room_list( + self, requester: Requester, room_id: str, visibility: str + ): """Edit the entry of the room in the published room list. requester - room_id (str) - visibility (str): "public" or "private" + room_id + visibility: "public" or "private" """ user_id = requester.user.to_string() @@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler): if room is None: raise SynapseError(400, "Unknown room") - yield self.auth.check_can_change_room_list(room_id, requester.user) + can_change_room_list = yield self.auth.check_can_change_room_list( + room_id, requester.user + ) + if not can_change_room_list: + raise AuthError( + 403, + "This server requires you to be a moderator in the room to" + " edit its room list entry", + ) making_public = visibility == "public" if making_public: @@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def edit_published_appservice_room_list( - self, appservice_id, network_id, room_id, visibility + self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public room list. Args: - appservice_id (str): ID of the appservice that owns the list - network_id (str): The ID of the network the list is associated with - room_id (str) - visibility (str): either "public" or "private" + appservice_id: ID of the appservice that owns the list + network_id: The ID of the network the list is associated with + room_id + visibility: either "public" or "private" """ if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 3397cfa485..5e40adba52 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,6 +18,7 @@ from mock import Mock from twisted.internet import defer +import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig @@ -87,52 +88,131 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_delete_alias_not_allowed(self): - """Removing an alias should be denied if a user does not have the proper permissions.""" - room_id = "!8765qwer:test" + def test_incoming_fed_query(self): + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) + ) + + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) + ) + + self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestDeleteAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def _create_alias(self, user): + # Create a new alias to this room. self.get_success( - self.store.create_room_alias_association(self.my_room, room_id, ["test"]) + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], user + ) ) + def test_delete_alias_not_allowed(self): + """A user that doesn't meet the expected guidelines cannot delete an alias.""" + self._create_alias(self.admin_user) self.get_failure( self.handler.delete_association( - create_requester("@user:test"), self.my_room + create_requester(self.test_user), self.room_alias ), synapse.api.errors.AuthError, ) - def test_delete_alias(self): - """Removing an alias should work when a user does has the proper permissions.""" - room_id = "!8765qwer:test" - user_id = "@user:test" - self.get_success( - self.store.create_room_alias_association( - self.my_room, room_id, ["test"], user_id + def test_delete_alias_creator(self): + """An alias creator can delete their own alias.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias ) ) + self.assertEquals(self.room_id, result) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_admin(self): + """A server admin can delete an alias created by another user.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias as the admin. result = self.get_success( - self.handler.delete_association(create_requester(user_id), self.my_room) + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) ) - self.assertEquals(room_id, result) + self.assertEquals(self.room_id, result) - # The alias should not be found. + # Confirm the alias is gone. self.get_failure( - self.handler.get_association(self.my_room), synapse.api.errors.SynapseError + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, ) - def test_incoming_fed_query(self): - self.get_success( - self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] - ) + def test_delete_alias_sufficient_power(self): + """A user with a sufficient power level should be able to delete an alias.""" + self._create_alias(self.admin_user) + + # Increase the user's power level. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"users": {self.test_user: 100}}, + tok=self.admin_user_tok, ) - response = self.get_success( - self.handler.on_directory_query({"room_alias": "#your-room:test"}) + # They can now delete the alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) ) + self.assertEquals(self.room_id, result) - self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) class CanonicalAliasTestCase(unittest.HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 097ebb8774..7622aa19f1 100644 --- a/tox.ini +++ b/tox.ini @@ -185,6 +185,7 @@ commands = mypy \ synapse/federation/federation_client.py \ synapse/federation/sender \ synapse/federation/transport \ + synapse/handlers/directory.py \ synapse/handlers/presence.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ -- cgit 1.4.1 From 1f5f3ae8b1c5db96d36ac7c104f13553bc4283da Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Sun, 8 Mar 2020 14:49:33 +0100 Subject: Add options to disable setting profile info for prevent changes. --- synapse/config/registration.py | 11 +++++++++++ synapse/handlers/profile.py | 10 ++++++++++ tests/handlers/test_profile.py | 33 ++++++++++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) (limited to 'tests/handlers') diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 9bb3beedbc..d9f452dcea 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,6 +129,9 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.disable_set_displayname = config.get("disable_set_displayname", False) + self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -330,6 +333,14 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process + # If enabled, don't let users set their own display names/avatars + # other than for the very first time (unless they are a server admin). + # Useful when provisioning users based on the contents of a 3rd party + # directory and to avoid ambiguities. + # + # disable_set_displayname: False + # disable_set_avatar_url: False + # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 50ce0c585b..fb7e84f3b8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,6 +157,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") + if not by_admin and self.hs.config.disable_set_displayname: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.display_name: + raise SynapseError(400, "Changing displayname is disabled on this server") + if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -218,6 +223,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") + if not by_admin and self.hs.config.disable_set_avatar_url: + profile = yield self.store.get_profileinfo(target_user.localpart) + if profile.avatar_url: + raise SynapseError(400, "Changing avatar url is disabled on this server") + if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..b85520c688 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,6 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.config = hs.config @defer.inlineCallbacks def test_get_my_name(self): @@ -90,6 +91,19 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.config.disable_set_displayname = True + + # Set first displayname is allowed, if displayname is null + self.store.set_profile_displayname(self.frank.localpart, "Frank") + + d = self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -147,3 +161,20 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.config.disable_set_avatar_url = True + + # Set first time avatar is allowed, if displayname is null + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + d = self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + + yield self.assertFailure(d, SynapseError) -- cgit 1.4.1 From 04f4b5f6f87fbba0b2f1a4f011c496de3021c81a Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 19:51:31 +0100 Subject: add tests --- tests/handlers/test_profile.py | 6 +- tests/rest/client/v2_alpha/test_account.py | 308 +++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+), 3 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index b85520c688..98b508c3d4 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -70,7 +70,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() - self.config = hs.config + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -93,7 +93,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_if_disabled(self): - self.config.disable_set_displayname = True + self.hs.config.disable_set_displayname = True # Set first displayname is allowed, if displayname is null self.store.set_profile_displayname(self.frank.localpart, "Frank") @@ -164,7 +164,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): - self.config.disable_set_avatar_url = True + self.hs.config.disable_set_avatar_url = True # Set first time avatar is allowed, if displayname is null self.store.set_profile_avatar_url( diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index c3facc00eb..ac9f200de3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -325,3 +325,311 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. + self.email_attempts = [] + + def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): + self.email_attempts.append(msg) + return + + config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_email(self): + """Test add mail to profile + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_add_email_if_disabled(self): + """Test add mail to profile if disabled + """ + self.hs.config.disable_3pid_changes = True + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test delete mail from profile + """ + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + { + "medium": "email", + "address": self.email + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test delete mail from profile if disabled + """ + self.hs.config.disable_3pid_changes = True + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + request, channel = self.make_request( + "POST", + b"account/3pid/delete", + { + "medium": "email", + "address": self.email + }, + access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("3PID changes disabled on this server", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email + """ + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + request, channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.render(request) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + + # Get user + request, channel = self.make_request( + "GET", self.url_3pid, access_token=self.user_id_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def _request_token(self, email, client_secret): + request, channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + request, channel = self.make_request("GET", path, shorthand=False) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) -- cgit 1.4.1 From 7e5f40e7716813f0d32e2efcb32df3c263fbfc63 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 21:00:36 +0100 Subject: fix tests --- tests/handlers/test_profile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 98b508c3d4..f8c0da5ced 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -96,7 +96,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.disable_set_displayname = True # Set first displayname is allowed, if displayname is null - self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") d = self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -167,7 +167,7 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.disable_set_avatar_url = True # Set first time avatar is allowed, if displayname is null - self.store.set_profile_avatar_url( + yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) -- cgit 1.4.1 From 885134529ffd95dd118d3228e69f0e3553f5a6a7 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Mar 2020 22:09:29 +0100 Subject: updates after review --- changelog.d/7053.feature | 2 +- docs/sample_config.yaml | 10 +++++----- synapse/config/registration.py | 16 ++++++++-------- synapse/handlers/profile.py | 8 ++++---- synapse/rest/client/v2_alpha/account.py | 18 ++++++++++++------ tests/handlers/test_profile.py | 6 +++--- tests/rest/client/v2_alpha/test_account.py | 17 +++++++---------- 7 files changed, 40 insertions(+), 37 deletions(-) (limited to 'tests/handlers') diff --git a/changelog.d/7053.feature b/changelog.d/7053.feature index 79955b9780..00f47b2a14 100644 --- a/changelog.d/7053.feature +++ b/changelog.d/7053.feature @@ -1 +1 @@ -Add options to disable setting profile info for prevent changes. \ No newline at end of file +Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d3ecffac7d..8333800a10 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1057,18 +1057,18 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process -# If enabled, don't let users set their own display names/avatars +# If disabled, don't let users set their own display names/avatars # other than for the very first time (unless they are a server admin). # Useful when provisioning users based on the contents of a 3rd party # directory and to avoid ambiguities. # -#disable_set_displayname: false -#disable_set_avatar_url: false +#enable_set_displayname: true +#enable_set_avatar_url: true -# If true, stop users from trying to change the 3PIDs associated with +# If false, stop users from trying to change the 3PIDs associated with # their accounts. # -#disable_3pid_changes: false +#enable_3pid_changes: true # Users who register on this homeserver will automatically be joined # to these rooms diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1abc0a79af..d4897ec9b6 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,9 +129,9 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.disable_set_displayname = config.get("disable_set_displayname", False) - self.disable_set_avatar_url = config.get("disable_set_avatar_url", False) - self.disable_3pid_changes = config.get("disable_3pid_changes", False) + self.enable_set_displayname = config.get("enable_set_displayname", True) + self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + self.enable_3pid_changes = config.get("enable_3pid_changes", True) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False @@ -334,18 +334,18 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process - # If enabled, don't let users set their own display names/avatars + # If disabled, don't let users set their own display names/avatars # other than for the very first time (unless they are a server admin). # Useful when provisioning users based on the contents of a 3rd party # directory and to avoid ambiguities. # - #disable_set_displayname: false - #disable_set_avatar_url: false + #enable_set_displayname: true + #enable_set_avatar_url: true - # If true, stop users from trying to change the 3PIDs associated with + # If false, stop users from trying to change the 3PIDs associated with # their accounts. # - #disable_3pid_changes: false + #enable_3pid_changes: true # Users who register on this homeserver will automatically be joined # to these rooms diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b049dd8e26..eb85dba015 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,11 +157,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and self.hs.config.disable_set_displayname: + if not by_admin and not self.hs.config.enable_set_displayname: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( - 400, "Changing displayname is disabled on this server" + 400, "Changing display name is disabled on this server", Codes.FORBIDDEN ) if len(new_displayname) > MAX_DISPLAYNAME_LEN: @@ -225,11 +225,11 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and self.hs.config.disable_set_avatar_url: + if not by_admin and not self.hs.config.enable_set_avatar_url: profile = yield self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( - 400, "Changing avatar url is disabled on this server" + 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN ) if len(new_avatar_url) > MAX_AVATAR_URL_LEN: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 97bddf36d9..e40136f2f3 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -599,8 +599,10 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -646,8 +648,10 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -749,8 +753,10 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): - if self.hs.config.disable_3pid_changes: - raise SynapseError(400, "3PID changes disabled on this server") + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f8c0da5ced..e600b9777b 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -93,7 +93,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name_if_disabled(self): - self.hs.config.disable_set_displayname = True + self.hs.config.enable_set_displayname = False # Set first displayname is allowed, if displayname is null yield self.store.set_profile_displayname(self.frank.localpart, "Frank") @@ -164,9 +164,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): - self.hs.config.disable_set_avatar_url = True + self.hs.config.enable_set_avatar_url = False - # Set first time avatar is allowed, if displayname is null + # Set first time avatar is allowed, if avatar is null yield self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index e178a53335..34e40a36d0 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -412,7 +413,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test add mail to profile if disabled """ - self.hs.config.disable_3pid_changes = True + self.hs.config.enable_3pid_changes = True client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -438,9 +439,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "3PID changes disabled on this server", channel.json_body["error"] - ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -486,7 +485,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test delete mail from profile if disabled """ - self.hs.config.disable_3pid_changes = True + self.hs.config.enable_3pid_changes = True # Add a threepid self.get_success( @@ -508,9 +507,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual( - "3PID changes disabled on this server", channel.json_body["error"] - ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -547,7 +544,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user request, channel = self.make_request( @@ -582,7 +579,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("No validated 3pid session found", channel.json_body["error"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user request, channel = self.make_request( -- cgit 1.4.1 From 6a35046363a6f5d41199256c80eef4ea7e385986 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 17 Mar 2020 11:25:01 +0000 Subject: Revert "Add options to disable setting profile info for prevent changes. (#7053)" This reverts commit 54dd28621b070ca67de9f773fe9a89e1f4dc19da, reversing changes made to 6640460d054e8f4444046a34bdf638921b31c01e. --- changelog.d/7053.feature | 1 - docs/sample_config.yaml | 13 -- synapse/config/registration.py | 17 -- synapse/handlers/profile.py | 16 -- synapse/rest/client/v2_alpha/account.py | 16 -- tests/handlers/test_profile.py | 33 +--- tests/rest/client/v2_alpha/test_account.py | 303 ----------------------------- 7 files changed, 1 insertion(+), 398 deletions(-) delete mode 100644 changelog.d/7053.feature (limited to 'tests/handlers') diff --git a/changelog.d/7053.feature b/changelog.d/7053.feature deleted file mode 100644 index 00f47b2a14..0000000000 --- a/changelog.d/7053.feature +++ /dev/null @@ -1 +0,0 @@ -Add options to prevent users from changing their profile or associated 3PIDs. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 91eff4c8ad..2ff0dd05a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1057,19 +1057,6 @@ account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process -# If disabled, don't let users set their own display names/avatars -# (unless they are a server admin) other than for the very first time. -# Useful when provisioning users based on the contents of a 3rd party -# directory and to avoid ambiguities. -# -#enable_set_displayname: true -#enable_set_avatar_url: true - -# If false, stop users from trying to change the 3PIDs associated with -# their accounts. -# -#enable_3pid_changes: true - # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ee737eb40d..9bb3beedbc 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -129,10 +129,6 @@ class RegistrationConfig(Config): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) - self.enable_set_displayname = config.get("enable_set_displayname", True) - self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) - self.enable_3pid_changes = config.get("enable_3pid_changes", True) - self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -334,19 +330,6 @@ class RegistrationConfig(Config): #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process - # If disabled, don't let users set their own display names/avatars - # (unless they are a server admin) other than for the very first time. - # Useful when provisioning users based on the contents of a 3rd party - # directory and to avoid ambiguities. - # - #enable_set_displayname: true - #enable_set_avatar_url: true - - # If false, stop users from trying to change the 3PIDs associated with - # their accounts. - # - #enable_3pid_changes: true - # Users who register on this homeserver will automatically be joined # to these rooms # diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 6aa1c0f5e0..50ce0c585b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -157,15 +157,6 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: - profile = yield self.store.get_profileinfo(target_user.localpart) - if profile.display_name: - raise SynapseError( - 400, - "Changing display name is disabled on this server", - Codes.FORBIDDEN, - ) - if len(new_displayname) > MAX_DISPLAYNAME_LEN: raise SynapseError( 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) @@ -227,13 +218,6 @@ class BaseProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: - profile = yield self.store.get_profileinfo(target_user.localpart) - if profile.avatar_url: - raise SynapseError( - 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN - ) - if len(new_avatar_url) > MAX_AVATAR_URL_LEN: raise SynapseError( 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e40136f2f3..dc837d6c75 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -599,11 +599,6 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -648,11 +643,6 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -748,16 +738,10 @@ class ThreepidDeleteRestServlet(RestServlet): def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index e600b9777b..d60c124eec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -19,7 +19,7 @@ from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -70,7 +70,6 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() - self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -91,19 +90,6 @@ class ProfileTestCase(unittest.TestCase): "Frank Jr.", ) - @defer.inlineCallbacks - def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False - - # Set first displayname is allowed, if displayname is null - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") - - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." - ) - - yield self.assertFailure(d, SynapseError) - @defer.inlineCallbacks def test_set_my_name_noauth(self): d = self.handler.set_displayname( @@ -161,20 +147,3 @@ class ProfileTestCase(unittest.TestCase): (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) - - @defer.inlineCallbacks - def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False - - # Set first time avatar is allowed, if avatar is null - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" - ) - - d = self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", - ) - - yield self.assertFailure(d, SynapseError) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 99cc9163f3..c3facc00eb 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,6 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register @@ -326,305 +325,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(request.code, 200) - - -class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - - servlets = [ - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Email config. - self.email_attempts = [] - - def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - - config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) - return self.hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.user_id = self.register_user("kermit", "test") - self.user_id_tok = self.login("kermit", "test") - self.email = "test@example.com" - self.url_3pid = b"account/3pid" - - def test_add_email(self): - """Test add mail to profile - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_add_email_if_disabled(self): - """Test add mail to profile if disabled - """ - self.hs.config.enable_3pid_changes = False - - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email(self): - """Test delete mail from profile - """ - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - request, channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email_if_disabled(self): - """Test delete mail from profile if disabled - """ - self.hs.config.enable_3pid_changes = False - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - request, channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to add email without clicking the link - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_no_valid_token(self): - """Test that we do actually need to request a token and can't just - make a session up. - """ - client_secret = "foobar" - session_id = "weasle" - - # Attempt to add email without even requesting an email - request, channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.render(request) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) - - # Get user - request, channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, - ) - self.render(request) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def _request_token(self, email, client_secret): - request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - ) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - return channel.json_body["sid"] - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) -- cgit 1.4.1