summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_validity.py32
-rw-r--r--synapse/handlers/appservice.py10
-rw-r--r--synapse/handlers/auth.py2
-rw-r--r--synapse/handlers/cas_handler.py3
-rw-r--r--synapse/handlers/deactivate_account.py4
-rw-r--r--synapse/handlers/device.py15
-rw-r--r--synapse/handlers/devicemessage.py25
-rw-r--r--synapse/handlers/directory.py68
-rw-r--r--synapse/handlers/e2e_keys.py14
-rw-r--r--synapse/handlers/e2e_room_keys.py7
-rw-r--r--synapse/handlers/federation.py78
-rw-r--r--synapse/handlers/groups_local.py4
-rw-r--r--synapse/handlers/identity.py298
-rw-r--r--synapse/handlers/message.py14
-rw-r--r--synapse/handlers/pagination.py28
-rw-r--r--synapse/handlers/presence.py43
-rw-r--r--synapse/handlers/profile.py216
-rw-r--r--synapse/handlers/register.py349
-rw-r--r--synapse/handlers/room.py110
-rw-r--r--synapse/handlers/room_list.py4
-rw-r--r--synapse/handlers/room_member.py64
-rw-r--r--synapse/handlers/set_password.py3
-rw-r--r--synapse/handlers/sync.py34
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/user_directory.py6
25 files changed, 1110 insertions, 390 deletions
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py

index 590135d19c..0c2bcda4d0 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -20,6 +20,8 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import List +from twisted.internet import defer + from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process @@ -43,6 +45,8 @@ class AccountValidityHandler(object): self.clock = self.hs.get_clock() self._account_validity = self.hs.config.account_validity + self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory + self.profile_handler = self.hs.get_profile_handler() if ( self._account_validity.enabled @@ -86,6 +90,13 @@ class AccountValidityHandler(object): self.clock.looping_call(send_emails, 30 * 60 * 1000) + # If account_validity is enabled,check every hour to remove expired users from + # the user directory + if self._account_validity.enabled: + self.clock.looping_call( + self._mark_expired_users_as_inactive, 60 * 60 * 1000 + ) + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` @@ -266,4 +277,25 @@ class AccountValidityHandler(object): user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) + # Check if renewed users should be reintroduced to the user directory + if self._show_users_in_user_directory: + # Show the user in the directory again by setting them to active + await self.profile_handler.set_active( + [UserID.from_string(user_id)], True, True + ) + return expiration_ts + + @defer.inlineCallbacks + def _mark_expired_users_as_inactive(self): + """Iterate over active, expired users. Mark them as inactive in order to hide them + from the user directory. + + Returns: + Deferred + """ + # Get active, expired users + active_expired_users = yield self.store.get_expired_users() + + # Mark each as non-active + yield self.profile_handler.set_active(active_expired_users, False, True) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index fe62f78e67..f7d9fd621e 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -15,8 +15,6 @@ import logging -from six import itervalues - from prometheus_client import Counter from twisted.internet import defer @@ -116,6 +114,12 @@ class ApplicationServicesHandler(object): for service in services: self.scheduler.submit_event_for_as(service, event) + now = self.clock.time_msec() + ts = yield self.store.get_received_ts(event.event_id) + synapse.metrics.event_processing_lag_by_event.labels( + "appservice_sender" + ).observe(now - ts) + @defer.inlineCallbacks def handle_room_events(events): for event in events: @@ -125,7 +129,7 @@ class ApplicationServicesHandler(object): defer.gatherResults( [ run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) + for evs in events_by_room.values() ], consumeErrors=True, ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bb3b43d5ae..c3f86e7414 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -297,7 +297,7 @@ class AuthHandler(BaseHandler): # Convert the URI and method to strings. uri = request.uri.decode("utf-8") - method = request.uri.decode("utf-8") + method = request.method.decode("utf-8") # If there's no session ID, create a new session. if not sid: diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 64aaa1335c..76f213723a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -14,11 +14,10 @@ # limitations under the License. import logging +import urllib import xml.etree.ElementTree as ET from typing import Dict, Optional, Tuple -from six.moves import urllib - from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 2afb390a92..fe62c3f973 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -33,6 +33,7 @@ class DeactivateAccountHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() self._identity_handler = hs.get_handlers().identity_handler + self._profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() # Flag that indicates whether the process to part users from rooms is running @@ -104,6 +105,9 @@ class DeactivateAccountHandler(BaseHandler): await self.store.user_set_password_hash(user_id, None) + user = UserID.from_string(user_id) + await self._profile_handler.set_active([user], False, False) + # Add the user to a table of users pending deactivation (ie. # removal from all the rooms they're a member of) await self.store.add_user_pending_deactivation(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 230d170258..31346b56c3 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -17,8 +17,6 @@ import logging from typing import Any, Dict, Optional -from six import iteritems, itervalues - from twisted.internet import defer from synapse.api import errors @@ -159,7 +157,7 @@ class DeviceWorkerHandler(BaseHandler): # The user may have left the room # TODO: Check if they actually did or if we were just invited. if room_id not in room_ids: - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -182,7 +180,7 @@ class DeviceWorkerHandler(BaseHandler): log_kv( {"event": "encountered empty previous state", "room_id": room_id} ) - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -198,10 +196,10 @@ class DeviceWorkerHandler(BaseHandler): # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. - for state_dict in itervalues(prev_state_ids): + for state_dict in prev_state_ids.values(): member_event = state_dict.get((EventTypes.Member, user_id), None) if not member_event or member_event != current_member_id: - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue @@ -211,14 +209,14 @@ class DeviceWorkerHandler(BaseHandler): # If there has been any change in membership, include them in the # possibly changed list. We'll check if they are joined below, # and we're not toooo worried about spuriously adding users. - for key, event_id in iteritems(current_state_ids): + for key, event_id in current_state_ids.items(): etype, state_key = key if etype != EventTypes.Member: continue # check if this member has changed since any of the extremities # at the stream_ordering, and add them to the list if so. - for state_dict in itervalues(prev_state_ids): + for state_dict in prev_state_ids.values(): prev_event_id = state_dict.get(key, None) if not prev_event_id or prev_event_id != event_id: if state_key != user_id: @@ -693,6 +691,7 @@ class DeviceListUpdater(object): return False + @trace @defer.inlineCallbacks def _maybe_retry_device_resync(self): """Retry to resync device lists that are out of sync, except if another retry is diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 05c4b3eec0..610b08d00b 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -18,8 +18,6 @@ from typing import Any, Dict from canonicaljson import json -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -51,8 +49,7 @@ class DeviceMessageHandler(object): self._device_list_updater = hs.get_device_handler().device_list_updater - @defer.inlineCallbacks - def on_direct_to_device_edu(self, origin, content): + async def on_direct_to_device_edu(self, origin, content): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): @@ -82,11 +79,11 @@ class DeviceMessageHandler(object): } local_messages[user_id] = messages_by_device - yield self._check_for_unknown_devices( + await self._check_for_unknown_devices( message_type, sender_user_id, by_device ) - stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + stream_id = await self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages ) @@ -94,14 +91,13 @@ class DeviceMessageHandler(object): "to_device_key", stream_id, users=local_messages.keys() ) - @defer.inlineCallbacks - def _check_for_unknown_devices( + async def _check_for_unknown_devices( self, message_type: str, sender_user_id: str, by_device: Dict[str, Dict[str, Any]], ): - """Checks inbound device messages for unkown remote devices, and if + """Checks inbound device messages for unknown remote devices, and if found marks the remote cache for the user as stale. """ @@ -115,7 +111,7 @@ class DeviceMessageHandler(object): requesting_device_ids.add(device_id) # Check if we are tracking the devices of the remote user. - room_ids = yield self.store.get_rooms_for_user(sender_user_id) + room_ids = await self.store.get_rooms_for_user(sender_user_id) if not room_ids: logger.info( "Received device message from remote device we don't" @@ -127,7 +123,7 @@ class DeviceMessageHandler(object): # If we are tracking check that we know about the sending # devices. - cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id) + cached_devices = await self.store.get_cached_devices_for_user(sender_user_id) unknown_devices = requesting_device_ids - set(cached_devices) if unknown_devices: @@ -136,15 +132,14 @@ class DeviceMessageHandler(object): sender_user_id, unknown_devices, ) - yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background run_in_background( self._device_list_updater.user_device_resync, sender_user_id ) - @defer.inlineCallbacks - def send_device_message(self, sender_user_id, message_type, messages): + async def send_device_message(self, sender_user_id, message_type, messages): set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} @@ -183,7 +178,7 @@ class DeviceMessageHandler(object): } log_kv({"local_messages": local_messages}) - stream_id = yield self.store.add_messages_to_device_inbox( + stream_id = await self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index f2f16b1e43..79a2df6201 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py
@@ -17,8 +17,6 @@ import logging import string from typing import Iterable, List, Optional -from twisted.internet import defer - from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( AuthError, @@ -55,8 +53,7 @@ class DirectoryHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() - @defer.inlineCallbacks - def _create_association( + async def _create_association( self, room_alias: RoomAlias, room_id: str, @@ -76,13 +73,13 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) servers = {get_domain_from_id(u) for u in users} if not servers: raise SynapseError(400, "Failed to get server list") - yield self.store.create_room_alias_association( + await self.store.create_room_alias_association( room_alias, room_id, servers, creator=creator ) @@ -93,7 +90,7 @@ class DirectoryHandler(BaseHandler): room_id: str, servers: Optional[List[str]] = None, check_membership: bool = True, - ): + ) -> None: """Attempt to create a new alias Args: @@ -103,9 +100,6 @@ class DirectoryHandler(BaseHandler): servers: Iterable of servers that others servers should try and join via check_membership: Whether to check if the user is in the room before the alias can be set (if the server's config requires it). - - Returns: - Deferred """ user_id = requester.user.to_string() @@ -148,7 +142,7 @@ class DirectoryHandler(BaseHandler): # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") - can_create = await self.can_modify_alias(room_alias, user_id=user_id) + can_create = self.can_modify_alias(room_alias, user_id=user_id) if not can_create: raise AuthError( 400, @@ -158,7 +152,9 @@ class DirectoryHandler(BaseHandler): await self._create_association(room_alias, room_id, servers, creator=user_id) - async def delete_association(self, requester: Requester, room_alias: RoomAlias): + async def delete_association( + self, requester: Requester, room_alias: RoomAlias + ) -> str: """Remove an alias from the directory (this is only meant for human users; AS users should call @@ -169,7 +165,7 @@ class DirectoryHandler(BaseHandler): room_alias Returns: - Deferred[unicode]: room id that the alias used to point to + room id that the alias used to point to Raises: NotFoundError: if the alias doesn't exist @@ -191,7 +187,7 @@ class DirectoryHandler(BaseHandler): if not can_delete: raise AuthError(403, "You don't have permission to delete the alias.") - can_delete = await self.can_modify_alias(room_alias, user_id=user_id) + can_delete = self.can_modify_alias(room_alias, user_id=user_id) if not can_delete: raise SynapseError( 400, @@ -208,8 +204,7 @@ class DirectoryHandler(BaseHandler): return room_id - @defer.inlineCallbacks - def delete_appservice_association( + async def delete_appservice_association( self, service: ApplicationService, room_alias: RoomAlias ): if not service.is_interested_in_alias(room_alias.to_string()): @@ -218,29 +213,27 @@ class DirectoryHandler(BaseHandler): "This application service has not reserved this kind of alias", errcode=Codes.EXCLUSIVE, ) - yield self._delete_association(room_alias) + await self._delete_association(room_alias) - @defer.inlineCallbacks - def _delete_association(self, room_alias: RoomAlias): + async def _delete_association(self, room_alias: RoomAlias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") - room_id = yield self.store.delete_room_alias(room_alias) + room_id = await self.store.delete_room_alias(room_alias) return room_id - @defer.inlineCallbacks - def get_association(self, room_alias: RoomAlias): + async def get_association(self, room_alias: RoomAlias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.get_association_from_room_alias(room_alias) + result = await self.get_association_from_room_alias(room_alias) if result: room_id = result.room_id servers = result.servers else: try: - result = yield self.federation.make_query( + result = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, @@ -265,7 +258,7 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND, ) - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) @@ -277,13 +270,12 @@ class DirectoryHandler(BaseHandler): return {"room_id": room_id, "servers": servers} - @defer.inlineCallbacks - def on_directory_query(self, args): + async def on_directory_query(self, args): room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room Alias is not hosted on this homeserver") - result = yield self.get_association_from_room_alias(room_alias) + result = await self.get_association_from_room_alias(room_alias) if result is not None: return {"room_id": result.room_id, "servers": result.servers} @@ -344,16 +336,15 @@ class DirectoryHandler(BaseHandler): ratelimit=False, ) - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias: RoomAlias): - result = yield self.store.get_association_from_room_alias(room_alias) + async def get_association_from_room_alias(self, room_alias: RoomAlias): + result = await self.store.get_association_from_room_alias(room_alias) if not result: # Query AS to see if it exists as_handler = self.appservice_handler - result = yield as_handler.query_room_alias_exists(room_alias) + result = await as_handler.query_room_alias_exists(room_alias) return result - def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): + def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool: # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have @@ -366,12 +357,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - return defer.succeed(True) + return True elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - return defer.succeed(False) + return False # either no interested services, or no service with an exclusive lock - return defer.succeed(True) + return True async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): """Determine whether a user can delete an alias. @@ -459,8 +450,7 @@ class DirectoryHandler(BaseHandler): await self.store.set_room_is_public(room_id, making_public) - @defer.inlineCallbacks - def edit_published_appservice_room_list( + async def edit_published_appservice_room_list( self, appservice_id: str, network_id: str, room_id: str, visibility: str ): """Add or remove a room from the appservice/network specific public @@ -475,7 +465,7 @@ class DirectoryHandler(BaseHandler): if visibility not in ["public", "private"]: raise SynapseError(400, "Invalid visibility setting") - yield self.store.set_room_is_public_appservice( + await self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 774a252619..a7e60cbc26 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -17,8 +17,6 @@ import logging -from six import iteritems - import attr from canonicaljson import encode_canonical_json, json from signedjson.key import decode_verify_key_bytes @@ -135,7 +133,7 @@ class E2eKeysHandler(object): remote_queries_not_in_cache = {} if remote_queries: query_list = [] - for user_id, device_ids in iteritems(remote_queries): + for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) else: @@ -145,9 +143,9 @@ class E2eKeysHandler(object): user_ids_not_in_cache, remote_results, ) = yield self.store.get_user_devices_from_cache(query_list) - for user_id, devices in iteritems(remote_results): + for user_id, devices in remote_results.items(): user_devices = results.setdefault(user_id, {}) - for device_id, device in iteritems(devices): + for device_id, device in devices.items(): keys = device.get("keys", None) device_display_name = device.get("device_display_name", None) if keys: @@ -446,9 +444,9 @@ class E2eKeysHandler(object): ",".join( ( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) + for user_id, user_keys in json_result.items() + for device_id, device_keys in user_keys.items() + for key_id, _ in device_keys.items() ) ), ) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 9abaf13b8f..f55470a707 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py
@@ -16,8 +16,6 @@ import logging -from six import iteritems - from twisted.internet import defer from synapse.api.errors import ( @@ -205,8 +203,8 @@ class E2eRoomKeysHandler(object): ) to_insert = [] # batch the inserts together changed = False # if anything has changed, we need to update the etag - for room_id, room in iteritems(room_keys["rooms"]): - for session_id, room_key in iteritems(room["sessions"]): + for room_id, room in room_keys["rooms"].items(): + for session_id, room_key in room["sessions"].items(): if not isinstance(room_key["is_verified"], bool): msg = ( "is_verified must be a boolean in keys for session %s in" @@ -351,6 +349,7 @@ class E2eRoomKeysHandler(object): raise res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) + res["etag"] = str(res["etag"]) return res @trace diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d0b62f4cf2..8a552ca7b5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -19,12 +19,9 @@ import itertools import logging +from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Sequence, Tuple -import six -from six import iteritems, itervalues -from six.moves import http_client, zip - import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json @@ -33,7 +30,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RejectedReason +from synapse.api.constants import ( + EventTypes, + Membership, + RejectedReason, + RoomEncryptionAlgorithms, +) from synapse.api.errors import ( AuthError, CodeMessageException, @@ -171,7 +173,7 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info("handling received PDU: %s", pdu) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = await self.store.get_event( @@ -286,6 +288,14 @@ class FederationHandler(BaseHandler): room_id, event_id, ) + elif missing_prevs: + logger.info( + "[%s %s] Not recursively fetching %d missing prev_events: %s", + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), + ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -330,12 +340,6 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) - logger.info( - "Event %s is missing prev_events: calculating state for a " - "backwards extremity", - event_id, - ) - # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. event_map = {event_id: pdu} @@ -353,7 +357,10 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "Requesting state at missing prev_event %s", event_id, + "[%s %s] Requesting state at missing prev_event %s", + room_id, + event_id, + p, ) with nested_logging_context(p): @@ -374,6 +381,7 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(room_id) state_map = await resolve_events_with_store( + self.clock, room_id, room_version, state_maps, @@ -387,13 +395,11 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, + list(state_map.values()), get_prev_content=False, ) event_map.update(evs) - state = [event_map[e] for e in six.itervalues(state_map)] + state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -742,7 +748,10 @@ class FederationHandler(BaseHandler): if device: keys = device.get("keys", {}).get("keys", {}) - if event.content.get("algorithm") == "m.megolm.v1.aes-sha2": + if ( + event.content.get("algorithm") + == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 + ): # For this algorithm we expect a curve25519 key. key_name = "curve25519:%s" % (device_id,) current_keys = [keys.get(key_name)] @@ -1001,7 +1010,7 @@ class FederationHandler(BaseHandler): """ joined_users = [ (state_key, int(event.depth)) - for (e_type, state_key), event in iteritems(state) + for (e_type, state_key), event in state.items() if e_type == EventTypes.Member and event.membership == Membership.JOIN ] @@ -1091,16 +1100,16 @@ class FederationHandler(BaseHandler): states = dict(zip(event_ids, [s.state for s in states])) state_map = await self.store.get_events( - [e_id for ids in itervalues(states) for e_id in itervalues(ids)], + [e_id for ids in states.values() for e_id in ids.values()], get_prev_content=False, ) states = { key: { k: state_map[e_id] - for k, e_id in iteritems(state_dict) + for k, e_id in state_dict.items() if e_id in state_map } - for key, state_dict in iteritems(states) + for key, state_dict in states.items() } for e_id, _ in sorted_extremeties_tuple: @@ -1188,7 +1197,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.prev_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") if len(ev.auth_event_ids()) > 10: logger.warning( @@ -1196,7 +1205,7 @@ class FederationHandler(BaseHandler): ev.event_id, len(ev.auth_event_ids()), ) - raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events") + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): """ Sends the invite to the remote server for signing. @@ -1517,8 +1526,15 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") + is_published = await self.store.is_room_published(event.room_id) + if not self.spam_checker.user_may_invite( - event.sender, event.state_key, event.room_id + event.sender, + event.state_key, + None, + room_id=event.room_id, + new_room=False, + published_room=is_published, ): raise SynapseError( 403, "This user is not permitted to send invites to this server/user" @@ -1539,7 +1555,7 @@ class FederationHandler(BaseHandler): # block any attempts to invite the server notices mxid if event.state_key == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a @@ -1725,7 +1741,7 @@ class FederationHandler(BaseHandler): state_groups = await self.state_store.get_state_groups(room_id, [event_id]) if state_groups: - _, state = list(iteritems(state_groups)).pop() + _, state = list(state_groups.items()).pop() results = {(e.type, e.state_key): e for e in state} if event.is_state(): @@ -2088,7 +2104,7 @@ class FederationHandler(BaseHandler): room_version, state_sets, event ) current_state_ids = { - k: e.event_id for k, e in iteritems(current_state_ids) + k: e.event_id for k, e in current_state_ids.items() } else: current_state_ids = await self.state_handler.get_current_state_ids( @@ -2104,7 +2120,7 @@ class FederationHandler(BaseHandler): # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) current_state_ids = [ - e for k, e in iteritems(current_state_ids) if k in auth_types + e for k, e in current_state_ids.items() if k in auth_types ] current_auth_events = await self.store.get_events(current_state_ids) @@ -2420,7 +2436,7 @@ class FederationHandler(BaseHandler): else: event_key = None state_updates = { - k: a.event_id for k, a in iteritems(auth_events) if k != event_key + k: a.event_id for k, a in auth_events.items() if k != event_key } current_state_ids = await context.get_current_state_ids() @@ -2431,7 +2447,7 @@ class FederationHandler(BaseHandler): prev_state_ids = await context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) - prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) # create a new state group as a delta from the existing one. prev_group = context.state_group diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ebe8d25bd8..7cb106e365 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -16,8 +16,6 @@ import logging -from six import iteritems - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.types import get_domain_from_id @@ -227,7 +225,7 @@ class GroupsLocalWorkerHandler(object): results = {} failed_results = [] - for destination, dest_user_ids in iteritems(destinations): + for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( destination, list(dest_user_ids) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 4ba0042768..a77088e295 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 +from twisted.internet import defer from twisted.internet.error import TimeoutError from synapse.api.errors import ( @@ -32,6 +33,7 @@ from synapse.api.errors import ( CodeMessageException, Codes, HttpResponseException, + ProxiedRequestError, SynapseError, ) from synapse.config.emailconfig import ThreepidBehaviour @@ -43,29 +45,34 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -id_server_scheme = "https://" - class IdentityHandler(BaseHandler): def __init__(self, hs): super(IdentityHandler, self).__init__(hs) - self.http_client = SimpleHttpClient(hs) + self.hs = hs + self.http_client = hs.get_simple_http_client() # We create a blacklisting instance of SimpleHttpClient for contacting identity # servers specified by clients self.blacklisting_http_client = SimpleHttpClient( hs, ip_blacklist=hs.config.federation_ip_range_blacklist ) self.federation_http_client = hs.get_http_client() - self.hs = hs - async def threepid_from_creds(self, id_server, creds): + self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers) + self.trust_any_id_server_just_for_testing_do_not_use = ( + hs.config.use_insecure_ssl_client_just_for_testing_do_not_use + ) + self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls + self._enable_lookup = hs.config.enable_3pid_lookup + + async def threepid_from_creds(self, id_server_url, creds): """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server Args: - id_server (str): The identity server to validate 3PIDs against. Must be a + id_server_url (str): The identity server to validate 3PIDs against. Must be a complete URL including the protocol (http(s)://) creds (dict[str, str]): Dictionary containing the following keys: @@ -92,7 +99,14 @@ class IdentityHandler(BaseHandler): query_params = {"sid": session_id, "client_secret": client_secret} - url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) + + url = "%s%s" % ( + id_server_url, + "/_matrix/identity/api/v1/3pid/getValidated3pid", + ) try: data = await self.http_client.get_json(url, query_params) @@ -101,7 +115,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info( "%s returned %i for threepid validation for: %s", - id_server, + id_server_url, e.code, creds, ) @@ -115,7 +129,7 @@ class IdentityHandler(BaseHandler): if "medium" in data: return data - logger.info("%s reported non-validated threepid: %s", id_server, creds) + logger.info("%s reported non-validated threepid: %s", id_server_url, creds) return None async def bind_threepid( @@ -146,14 +160,19 @@ class IdentityHandler(BaseHandler): if id_access_token is None: use_v2 = False + # if we have a rewrite rule set for the identity server, + # apply it now, but only for sending the request (not + # storing in the database). + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + # Decide which API endpoint URLs to use headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} if use_v2: - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) + bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,) headers["Authorization"] = create_id_access_token_header(id_access_token) else: - bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) + bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,) try: # Use the blacklisting http client as this call is only to identity servers @@ -238,9 +257,6 @@ class IdentityHandler(BaseHandler): Deferred[bool]: True on success, otherwise False if the identity server doesn't support unbinding """ - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") - content = { "mxid": mxid, "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, @@ -249,6 +265,7 @@ class IdentityHandler(BaseHandler): # we abuse the federation http client to sign the request, but we have to send it # using the normal http client since we don't want the SRV lookup and want normal # 'browser-like' HTTPS. + url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") auth_headers = self.federation_http_client.build_auth_headers( destination=None, method="POST", @@ -258,6 +275,15 @@ class IdentityHandler(BaseHandler): ) headers = {b"Authorization": auth_headers} + # if we have a rewrite rule set for the identity server, + # apply it now. + # + # Note that destination_is has to be the real id_server, not + # the server we connect to. + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,) + try: # Use the blacklisting http client as this call is only to identity servers # provided by a client @@ -371,15 +397,34 @@ class IdentityHandler(BaseHandler): return session_id + def rewrite_id_server_url(self, url: str, add_https=False) -> str: + """Given an identity server URL, optionally add a protocol scheme + before rewriting it according to the rewrite_identity_server_urls + config option + + Adds https:// to the URL if specified, then tries to rewrite the + url. Returns either the rewritten URL or the URL with optional + protocol scheme additions. + """ + rewritten_url = url + if add_https: + rewritten_url = "https://" + rewritten_url + + rewritten_url = self.rewrite_identity_server_urls.get( + rewritten_url, rewritten_url + ) + logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url) + return rewritten_url + async def requestEmailToken( - self, id_server, email, client_secret, send_attempt, next_link=None + self, id_server_url, email, client_secret, send_attempt, next_link=None ): """ Request an external server send an email on our behalf for the purposes of threepid validation. Args: - id_server (str): The identity server to proxy to + id_server_url (str): The identity server to proxy to email (str): The email to send the message to client_secret (str): The unique client_secret sends by the user send_attempt (int): Which attempt this is @@ -393,6 +438,11 @@ class IdentityHandler(BaseHandler): "client_secret": client_secret, "send_attempt": send_attempt, } + + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) + if next_link: params["next_link"] = next_link @@ -407,7 +457,8 @@ class IdentityHandler(BaseHandler): try: data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", + "%s/_matrix/identity/api/v1/validate/email/requestToken" + % (id_server_url,), params, ) return data @@ -419,7 +470,7 @@ class IdentityHandler(BaseHandler): async def requestMsisdnToken( self, - id_server, + id_server_url, country, phone_number, client_secret, @@ -430,7 +481,7 @@ class IdentityHandler(BaseHandler): Request an external server send an SMS message on our behalf for the purposes of threepid validation. Args: - id_server (str): The identity server to proxy to + id_server_url (str): The identity server to proxy to country (str): The country code of the phone number phone_number (str): The number to send the message to client_secret (str): The unique client_secret sends by the user @@ -458,9 +509,13 @@ class IdentityHandler(BaseHandler): "details and update your config file." ) + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) try: data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", + "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" + % (id_server_url,), params, ) except HttpResponseException as e: @@ -554,6 +609,89 @@ class IdentityHandler(BaseHandler): logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) raise SynapseError(400, "Error contacting the identity server") + # TODO: The following two methods are used for proxying IS requests using + # the CS API. They should be consolidated with those in RoomMemberHandler + # https://github.com/matrix-org/synapse-dinsic/issues/25 + + @defer.inlineCallbacks + def proxy_lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + Deferred[dict]: The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + try: + data = yield self.http_client.get_json( + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), + {"medium": medium, "address": address}, + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + yield self._verify_any_signature(data, id_server) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %s: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + return data + + @defer.inlineCallbacks + def proxy_bulk_lookup_3pid(self, id_server, threepids): + """Looks up given 3pids in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + threepids ([[str, str]]): The third party identifiers to lookup, as + a list of 2-string sized lists ([medium, address]). + + Returns: + Deferred[dict]: The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + try: + data = yield self.http_client.post_json_get_json( + "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,), + {"threepids": threepids}, + ) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %s: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + defer.returnValue(data) + async def lookup_3pid(self, id_server, medium, address, id_access_token=None): """Looks up a 3pid in the passed identity server. @@ -568,10 +706,13 @@ class IdentityHandler(BaseHandler): Returns: str|None: the matrix ID of the 3pid, or None if it is not recognized. """ + # Rewrite id_server URL if necessary + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + if id_access_token is not None: try: results = await self._lookup_3pid_v2( - id_server, id_access_token, medium, address + id_server_url, id_access_token, medium, address ) return results @@ -589,14 +730,15 @@ class IdentityHandler(BaseHandler): logger.warning("Error when looking up hashing details: %s", e) return None - return await self._lookup_3pid_v1(id_server, medium, address) + return await self._lookup_3pid_v1(id_server, id_server_url, medium, address) - async def _lookup_3pid_v1(self, id_server, medium, address): + async def _lookup_3pid_v1(self, id_server, id_server_url, medium, address): """Looks up a 3pid in the passed identity server using v1 lookup. Args: id_server (str): The server name (including port, if required) of the identity server to use. + id_server_url (str): The actual, reachable domain of the id server medium (str): The type of the third party identifier (e.g. "email"). address (str): The third party identifier (e.g. "foo@example.com"). @@ -604,8 +746,8 @@ class IdentityHandler(BaseHandler): str: the matrix ID of the 3pid, or None if it is not recognized. """ try: - data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + data = await self.http_client.get_json( + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) @@ -621,12 +763,11 @@ class IdentityHandler(BaseHandler): return None - async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address): + async def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address): """Looks up a 3pid in the passed identity server using v2 lookup. Args: - id_server (str): The server name (including port, if required) - of the identity server to use. + id_server_url (str): The protocol scheme and domain of the id server id_access_token (str): The access token to authenticate to the identity server with medium (str): The type of the third party identifier (e.g. "email"). address (str): The third party identifier (e.g. "foo@example.com"). @@ -636,8 +777,8 @@ class IdentityHandler(BaseHandler): """ # Check what hashing details are supported by this identity server try: - hash_details = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), + hash_details = await self.http_client.get_json( + "%s/_matrix/identity/v2/hash_details" % (id_server_url,), {"access_token": id_access_token}, ) except TimeoutError: @@ -645,15 +786,14 @@ class IdentityHandler(BaseHandler): if not isinstance(hash_details, dict): logger.warning( - "Got non-dict object when checking hash details of %s%s: %s", - id_server_scheme, - id_server, + "Got non-dict object when checking hash details of %s: %s", + id_server_url, hash_details, ) raise SynapseError( 400, - "Non-dict object from %s%s during v2 hash_details request: %s" - % (id_server_scheme, id_server, hash_details), + "Non-dict object from %s during v2 hash_details request: %s" + % (id_server_url, hash_details), ) # Extract information from hash_details @@ -667,8 +807,8 @@ class IdentityHandler(BaseHandler): ): raise SynapseError( 400, - "Invalid hash details received from identity server %s%s: %s" - % (id_server_scheme, id_server, hash_details), + "Invalid hash details received from identity server %s: %s" + % (id_server_url, hash_details), ) # Check if any of the supported lookup algorithms are present @@ -690,7 +830,7 @@ class IdentityHandler(BaseHandler): else: logger.warning( "None of the provided lookup algorithms of %s are supported: %s", - id_server, + id_server_url, supported_lookup_algorithms, ) raise SynapseError( @@ -703,8 +843,8 @@ class IdentityHandler(BaseHandler): headers = {"Authorization": create_id_access_token_header(id_access_token)} try: - lookup_results = await self.blacklisting_http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), + lookup_results = await self.http_client.post_json_get_json( + "%s/_matrix/identity/v2/lookup" % (id_server_url,), { "addresses": [lookup_value], "algorithm": lookup_algorithm, @@ -731,30 +871,31 @@ class IdentityHandler(BaseHandler): mxid = lookup_results["mappings"].get(lookup_value) return mxid - async def _verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - try: - key_data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" - % (id_server_scheme, server_hostname, key_name) - ) - except TimeoutError: - raise SynapseError(500, "Timed out contacting identity server") + async def _verify_any_signature(self, data, id_server): + if id_server not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (id_server,)) + + for key_name, signature in data["signatures"][id_server].items(): + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + key_data = await self.http_client.get_json( + "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name) + ) if "public_key" not in key_data: raise AuthError( - 401, "No public key named %s from %s" % (key_name, server_hostname) + 401, "No public key named %s from %s" % (key_name, id_server) ) verify_signed_json( data, - server_hostname, + id_server, decode_verify_key_bytes( key_name, decode_base64(key_data["public_key"]) ), ) return + raise AuthError(401, "No signature from server %s" % (id_server,)) + async def ask_id_server_for_third_party_invite( self, requester, @@ -814,15 +955,17 @@ class IdentityHandler(BaseHandler): "sender_avatar_url": inviter_avatar_url, } + # Rewrite the identity server URL if necessary + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + # Add the identity service access token to the JSON body and use the v2 # Identity Service endpoints if id_access_token is present data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) + base_url = "%s/_matrix/identity" % (id_server_url,) if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_url, ) # Attempt a v2 lookup @@ -841,9 +984,8 @@ class IdentityHandler(BaseHandler): raise e if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_url, ) url = base_url + "/api/v1/store-invite" @@ -855,10 +997,7 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, - e, + "Error trying to call /store-invite on %s: %s", id_server_url, e, ) if data is None: @@ -871,10 +1010,9 @@ class IdentityHandler(BaseHandler): ) except HttpResponseException as e: logger.warning( - "Error calling /store-invite on %s%s with fallback " + "Error calling /store-invite on %s with fallback " "encoding: %s", - id_server_scheme, - id_server, + id_server_url, e, ) raise e @@ -895,6 +1033,30 @@ class IdentityHandler(BaseHandler): display_name = data["display_name"] return token, public_keys, fallback_public_key, display_name + async def bind_email_using_internal_sydent_api( + self, id_server_url: str, email: str, user_id: str, + ): + """Bind an email to a fully qualified user ID using the internal API of an + instance of Sydent. + + Args: + id_server_url: The URL of the Sydent instance + email: The email address to bind + user_id: The user ID to bind the email to + + Raises: + HTTPResponseException: On a non-2xx HTTP response. + """ + # id_server_url is assumed to have no trailing slashes + url = id_server_url + "/_matrix/identity/internal/bind" + body = { + "address": email, + "medium": "email", + "mxid": user_id, + } + + await self.http_client.post_json_get_json(url, body) + def create_id_access_token_header(id_access_token): """Create an Authorization header for passing to SimpleHttpClient as the header value diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 649ca1f08a..665ad19b5d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -17,8 +17,6 @@ import logging from typing import Optional, Tuple -from six import iteritems, itervalues, string_types - from canonicaljson import encode_canonical_json, json from twisted.internet import defer @@ -246,7 +244,7 @@ class MessageHandler(object): "avatar_url": profile.avatar_url, "display_name": profile.display_name, } - for user_id, profile in iteritems(users_with_profile) + for user_id, profile in users_with_profile.items() } def maybe_schedule_expiry(self, event): @@ -715,7 +713,7 @@ class EventCreationHandler(object): spam_error = self.spam_checker.check_event_for_spam(event) if spam_error: - if not isinstance(spam_error, string_types): + if not isinstance(spam_error, str): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) @@ -881,7 +879,9 @@ class EventCreationHandler(object): """ room_alias = RoomAlias.from_string(room_alias_str) try: - mapping = yield directory_handler.get_association(room_alias) + mapping = yield defer.ensureDeferred( + directory_handler.get_association(room_alias) + ) except SynapseError as e: # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. if e.errcode == Codes.NOT_FOUND: @@ -988,7 +988,7 @@ class EventCreationHandler(object): state_to_include_ids = [ e_id - for k, e_id in iteritems(current_state_ids) + for k, e_id in current_state_ids.items() if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -1002,7 +1002,7 @@ class EventCreationHandler(object): "content": e.content, "sender": e.sender, } - for e in itervalues(state_to_include) + for e in state_to_include.values() ] invitee = UserID.from_string(event.state_key) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d7442c62a7..da06582d4b 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -15,9 +15,6 @@ # limitations under the License. import logging -from six import iteritems - -from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership @@ -99,8 +96,7 @@ class PaginationHandler(object): job["longest_max_lifetime"], ) - @defer.inlineCallbacks - def purge_history_for_rooms_in_range(self, min_ms, max_ms): + async def purge_history_for_rooms_in_range(self, min_ms, max_ms): """Purge outdated events from rooms within the given retention range. If a default retention policy is defined in the server's configuration and its @@ -139,13 +135,13 @@ class PaginationHandler(object): include_null, ) - rooms = yield self.store.get_rooms_for_retention_period_in_range( + rooms = await self.store.get_rooms_for_retention_period_in_range( min_ms, max_ms, include_null ) logger.debug("[purge] Rooms to purge: %s", rooms) - for room_id, retention_policy in iteritems(rooms): + for room_id, retention_policy in rooms.items(): logger.info("[purge] Attempting to purge messages in room %s", room_id) if room_id in self._purges_in_progress_by_room: @@ -167,9 +163,9 @@ class PaginationHandler(object): # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime - stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) + stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) - r = yield self.store.get_room_event_before_stream_ordering( + r = await self.store.get_room_event_before_stream_ordering( room_id, stream_ordering, ) if not r: @@ -229,8 +225,7 @@ class PaginationHandler(object): ) return purge_id - @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, delete_local_events): + async def _purge_history(self, purge_id, room_id, token, delete_local_events): """Carry out a history purge on a room. Args: @@ -239,14 +234,11 @@ class PaginationHandler(object): token (str): topological token to delete events before delete_local_events (bool): True to delete local events as well as remote ones - - Returns: - Deferred """ self._purges_in_progress_by_room.add(room_id) try: - with (yield self.pagination_lock.write(room_id)): - yield self.storage.purge_events.purge_history( + with await self.pagination_lock.write(room_id): + await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -284,9 +276,7 @@ class PaginationHandler(object): await self.store.get_room_version_id(room_id) # first check that we have no users in this room - joined = await defer.maybeDeferred( - self.store.is_host_joined, room_id, self._server_name - ) + joined = await self.store.is_host_joined(room_id, self._server_name) if joined: raise SynapseError(400, "Users are still joined to this room") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3594f3b00f..d2f25ae12a 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -25,9 +25,7 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set - -from six import iteritems, itervalues +from typing import Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager @@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC): for user_id in user_ids } - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in iteritems(states) if not state] + missing = [user_id for user_id, state in states.items() if not state] if missing: new = { user_id: UserPresenceState.default(user_id) for user_id in missing @@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states( [ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) - for prev_state in itervalues(prev_states) + for prev_state in prev_states.values() ] ) self.external_process_last_updated_ms.pop(process_id, None) @@ -775,7 +773,9 @@ class PresenceHandler(BasePresenceHandler): return False - async def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -787,10 +787,31 @@ class PresenceHandler(BasePresenceHandler): - last_user_sync_ts(int) - status_msg(int) - currently_active(int) + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id, limit) + rows = await self.store.get_all_presence_updates( + instance_name, last_id, current_id, limit + ) return rows def notify_new_event(self): @@ -1087,7 +1108,7 @@ class PresenceEventSource(object): return (list(updates.values()), max_token) else: return ( - [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE], + [s for s in updates.values() if s.state != PresenceState.OFFLINE], max_token, ) @@ -1323,11 +1344,11 @@ def get_interested_remotes(store, states, state_handler): # hosts in those rooms. room_ids_to_states, users_to_states = yield get_interested_parties(store, states) - for room_id, states in iteritems(room_ids_to_states): + for room_id, states in room_ids_to_states.items(): hosts = yield state_handler.get_current_hosts_in_room(room_id) hosts_and_states.append((hosts, states)) - for user_id, states in iteritems(users_to_states): + for user_id, states in users_to_states.items(): host = get_domain_from_id(user_id) hosts_and_states.append(([host], states)) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 302efc1b9a..dd8979e750 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,10 +15,13 @@ # limitations under the License. import logging +from typing import List -from six import raise_from +from six.moves import range -from twisted.internet import defer +from signedjson.sign import sign_json + +from twisted.internet import defer, reactor from synapse.api.errors import ( AuthError, @@ -27,6 +31,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester, get_domain_from_id @@ -46,6 +51,8 @@ class BaseProfileHandler(BaseHandler): subclass MasterProfileHandler """ + PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000 + def __init__(self, hs): super(BaseProfileHandler, self).__init__(hs) @@ -56,6 +63,88 @@ class BaseProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() + + self.max_avatar_size = hs.config.max_avatar_size + self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes + self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to + + if hs.config.worker_app is None: + self.clock.looping_call( + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS + ) + + if len(self.hs.config.replicate_user_profiles_to) > 0: + reactor.callWhenRunning(self._assign_profile_replication_batches) + reactor.callWhenRunning(self._replicate_profiles) + # Add a looping call to replicate_profiles: this handles retries + # if the replication is unsuccessful when the user updated their + # profile. + self.clock.looping_call( + self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL + ) + + @defer.inlineCallbacks + def _assign_profile_replication_batches(self): + """If no profile replication has been done yet, allocate replication batch + numbers to each profile to start the replication process. + """ + logger.info("Assigning profile batch numbers...") + total = 0 + while True: + assigned = yield self.store.assign_profile_batch() + total += assigned + if assigned == 0: + break + logger.info("Assigned %d profile batch numbers", total) + + @defer.inlineCallbacks + def _replicate_profiles(self): + """If any profile data has been updated and not pushed to the replication targets, + replicate it. + """ + host_batches = yield self.store.get_replication_hosts() + latest_batch = yield self.store.get_latest_profile_replication_batch_number() + if latest_batch is None: + latest_batch = -1 + for repl_host in self.hs.config.replicate_user_profiles_to: + if repl_host not in host_batches: + host_batches[repl_host] = -1 + try: + for i in range(host_batches[repl_host] + 1, latest_batch + 1): + yield self._replicate_host_profile_batch(repl_host, i) + except Exception: + logger.exception( + "Exception while replicating to %s: aborting for now", repl_host + ) + + @defer.inlineCallbacks + def _replicate_host_profile_batch(self, host, batchnum): + logger.info("Replicating profile batch %d to %s", batchnum, host) + batch_rows = yield self.store.get_profile_batch(batchnum) + batch = { + UserID(r["user_id"], self.hs.hostname).to_string(): ( + {"display_name": r["displayname"], "avatar_url": r["avatar_url"]} + if r["active"] + else None + ) + for r in batch_rows + } + + url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,) + body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname} + signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) + try: + yield self.http_client.post_json_get_json(url, signed_body) + yield self.store.update_replication_batch_for_host(host, batchnum) + logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host) + except Exception: + # This will get retried when the looping call next comes around + logger.exception( + "Failed to replicate profile batch %d to %s", batchnum, host + ) + raise + @defer.inlineCallbacks def get_profile(self, user_id): target_user = UserID.from_string(user_id) @@ -84,7 +173,7 @@ class BaseProfileHandler(BaseHandler): ) return result except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -135,7 +224,7 @@ class BaseProfileHandler(BaseHandler): ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -155,7 +244,7 @@ class BaseProfileHandler(BaseHandler): if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this homeserver") - if not by_admin and target_user != requester.user: + if not by_admin and requester and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if not by_admin and not self.hs.config.enable_set_displayname: @@ -175,13 +264,23 @@ class BaseProfileHandler(BaseHandler): if new_displayname == "": new_displayname = None + if len(self.hs.config.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + # If the admin changes the display name of a user, the requesting user cannot send # the join event to update the displayname in the rooms. # This must be done by the target user himself. if by_admin: requester = create_requester(target_user) - await self.store.set_profile_displayname(target_user.localpart, new_displayname) + await self.store.set_profile_displayname( + target_user.localpart, new_displayname, new_batchnum + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) @@ -191,6 +290,50 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) + # start a profile replication push + run_in_background(self._replicate_profiles) + + @defer.inlineCallbacks + def set_active( + self, users: List[UserID], active: bool, hide: bool, + ): + """ + Sets the 'active' flag on a set of user profiles. If set to false, the + accounts are considered deactivated or hidden. + + If 'hide' is true, then we interpret active=False as a request to try to + hide the users rather than deactivating them. This means withholding the + profiles from replication (and mark it as inactive) rather than clearing + the profile from the HS DB. + + Note that unlike set_displayname and set_avatar_url, this does *not* + perform authorization checks! This is because the only place it's used + currently is in account deactivation where we've already done these + checks anyway. + + Args: + users: The users to modify + active: Whether to set the user to active or inactive + hide: Whether to hide the user (withold from replication). If + False and active is False, user will have their profile + erased + + Returns: + Deferred + """ + if len(self.replicate_user_profiles_to) > 0: + cur_batchnum = ( + yield self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + + yield self.store.set_profiles_active(users, active, hide, new_batchnum) + + # start a profile replication push + run_in_background(self._replicate_profiles) + @defer.inlineCallbacks def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): @@ -212,7 +355,7 @@ class BaseProfileHandler(BaseHandler): ignore_backoff=True, ) except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to fetch profile"), e) + raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: raise e.to_synapse_error() @@ -241,11 +384,51 @@ class BaseProfileHandler(BaseHandler): 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,) ) + # Enforce a max avatar size if one is defined + if self.max_avatar_size or self.allowed_avatar_mimetypes: + media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url) + + # Check that this media exists locally + media_info = await self.store.get_local_media(media_id) + if not media_info: + raise SynapseError( + 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND + ) + + # Ensure avatar does not exceed max allowed avatar size + media_size = media_info["media_length"] + if self.max_avatar_size and media_size > self.max_avatar_size: + raise SynapseError( + 400, + "Avatars must be less than %s bytes in size" + % (self.max_avatar_size,), + errcode=Codes.TOO_LARGE, + ) + + # Ensure the avatar's file type is allowed + if ( + self.allowed_avatar_mimetypes + and media_info["media_type"] not in self.allowed_avatar_mimetypes + ): + raise SynapseError( + 400, "Avatar file type '%s' not allowed" % media_info["media_type"] + ) + # Same like set_displayname if by_admin: requester = create_requester(target_user) - await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) + if len(self.hs.config.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + + await self.store.set_profile_avatar_url( + target_user.localpart, new_avatar_url, new_batchnum + ) if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(target_user.localpart) @@ -255,6 +438,23 @@ class BaseProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) + # start a profile replication push + run_in_background(self._replicate_profiles) + + def _validate_and_parse_media_id_from_avatar_url(self, mxc): + """Validate and parse a provided avatar url and return the local media id + + Args: + mxc (str): A mxc URL + + Returns: + str: The ID of the media + """ + avatar_pieces = mxc.split("/") + if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:": + raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc) + return avatar_pieces[-1] + @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 51979ea43e..f223630d43 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -17,7 +17,7 @@ import logging from synapse import types -from synapse.api.constants import MAX_USERID_LENGTH, LoginType +from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict @@ -26,7 +26,8 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) -from synapse.types import RoomAlias, RoomID, UserID, create_requester +from synapse.storage.state import StateFilter +from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer from ._base import BaseHandler @@ -47,6 +48,7 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() self.identity_handler = self.hs.get_handlers().identity_handler self.ratelimiter = hs.get_registration_ratelimiter() @@ -59,6 +61,8 @@ class RegistrationHandler(BaseHandler): ) self._server_notices_mxid = hs.config.server_notices_mxid + self._show_in_user_directory = self.hs.config.show_users_in_user_directory + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -74,8 +78,18 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime async def check_username( - self, localpart, guest_access_token=None, assigned_user_id=None + self, localpart, guest_access_token=None, assigned_user_id=None, ): + """ + + Args: + localpart (str|None): The user's localpart + guest_access_token (str|None): A guest's access token + assigned_user_id (str|None): An existing User ID for this user if pre-calculated + + Returns: + Deferred + """ if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, @@ -118,6 +132,8 @@ class RegistrationHandler(BaseHandler): raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) + + # Retrieve guest user information from provided access token user_data = await self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( @@ -203,6 +219,11 @@ class RegistrationHandler(BaseHandler): address=address, ) + if default_display_name: + await self.profile_handler.set_displayname( + user, None, default_display_name, by_admin=True + ) + if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(localpart) await self.user_directory_handler.handle_local_profile_change( @@ -233,6 +254,10 @@ class RegistrationHandler(BaseHandler): address=address, ) + await self.profile_handler.set_displayname( + user, None, default_display_name, by_admin=True + ) + # Successfully registered break except SynapseError: @@ -266,55 +291,169 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - await self._register_email_threepid(user_id, threepid_dict, None) + await self.register_email_threepid(user_id, threepid_dict, None) + + # Prevent the new user from showing up in the user directory if the server + # mandates it. + if not self._show_in_user_directory: + await self.store.add_account_data_for_user( + user_id, "im.vector.hide_profile", {"hide_profile": True} + ) + await self.profile_handler.set_active([user], False, True) return user_id - async def _auto_join_rooms(self, user_id): - """Automatically joins users to auto join rooms - creating the room in the first place - if the user is the first to be created. + async def _create_and_join_rooms(self, user_id: str): + """ + Create the auto-join rooms and join or invite the user to them. + + This should only be called when the first "real" user registers. Args: - user_id(str): The user to join + user_id: The user to join """ - # auto-join the user to any rooms we're supposed to dump them into - fake_requester = create_requester(user_id) + # Getting the handlers during init gives a dependency loop. + room_creation_handler = self.hs.get_room_creation_handler() + room_member_handler = self.hs.get_room_member_handler() - # try to create the room if we're the first real user on the server. Note - # that an auto-generated support or bot user is not a real user and will never be - # the user to create the room - should_auto_create_rooms = False - is_real_user = await self.store.is_real_user(user_id) - if self.hs.config.autocreate_auto_join_rooms and is_real_user: - count = await self.store.count_real_users() - should_auto_create_rooms = count == 1 - for r in self.hs.config.auto_join_rooms: + # Generate a stub for how the rooms will be configured. + stub_config = { + "preset": self.hs.config.registration.autocreate_auto_join_room_preset, + } + + # If the configuration providers a user ID to create rooms with, use + # that instead of the first user registered. + requires_join = False + if self.hs.config.registration.auto_join_user_id: + fake_requester = create_requester( + self.hs.config.registration.auto_join_user_id + ) + + # If the room requires an invite, add the user to the list of invites. + if self.hs.config.registration.auto_join_room_requires_invite: + stub_config["invite"] = [user_id] + + # If the room is being created by a different user, the first user + # registered needs to join it. Note that in the case of an invitation + # being necessary this will occur after the invite was sent. + requires_join = True + else: + fake_requester = create_requester(user_id) + + # Choose whether to federate the new room. + if not self.hs.config.registration.autocreate_auto_join_rooms_federated: + stub_config["creation_content"] = {"m.federate": False} + + for r in self.hs.config.registration.auto_join_rooms: logger.info("Auto-joining %s to %s", user_id, r) + try: - if should_auto_create_rooms: - room_alias = RoomAlias.from_string(r) - if self.hs.hostname != room_alias.domain: - logger.warning( - "Cannot create room alias %s, " - "it does not match server domain", - r, - ) - else: - # create room expects the localpart of the room alias - room_alias_localpart = room_alias.localpart - - # getting the RoomCreationHandler during init gives a dependency - # loop - await self.hs.get_room_creation_handler().create_room( - fake_requester, - config={ - "preset": "public_chat", - "room_alias_name": room_alias_localpart, - }, + room_alias = RoomAlias.from_string(r) + + if self.hs.hostname != room_alias.domain: + logger.warning( + "Cannot create room alias %s, " + "it does not match server domain", + r, + ) + else: + # A shallow copy is OK here since the only key that is + # modified is room_alias_name. + config = stub_config.copy() + # create room expects the localpart of the room alias + config["room_alias_name"] = room_alias.localpart + + info, _ = await room_creation_handler.create_room( + fake_requester, config=config, ratelimit=False, + ) + + # If the room does not require an invite, but another user + # created it, then ensure the first user joins it. + if requires_join: + await room_member_handler.update_membership( + requester=create_requester(user_id), + target=UserID.from_string(user_id), + room_id=info["room_id"], + # Since it was just created, there are no remote hosts. + remote_room_hosts=[], + action="join", ratelimit=False, ) + + except ConsentNotGivenError as e: + # Technically not necessary to pull out this error though + # moving away from bare excepts is a good thing to do. + logger.error("Failed to join new user to %r: %r", r, e) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + + async def _join_rooms(self, user_id: str): + """ + Join or invite the user to the auto-join rooms. + + Args: + user_id: The user to join + """ + room_member_handler = self.hs.get_room_member_handler() + + for r in self.hs.config.registration.auto_join_rooms: + logger.info("Auto-joining %s to %s", user_id, r) + + try: + room_alias = RoomAlias.from_string(r) + + if RoomAlias.is_valid(r): + ( + room_id, + remote_room_hosts, + ) = await room_member_handler.lookup_room_alias(room_alias) + room_id = room_id.to_string() else: - await self._join_user_to_room(fake_requester, r) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (r,) + ) + + # Calculate whether the room requires an invite or can be + # joined directly. Note that unless a join rule of public exists, + # it is treated as requiring an invite. + requires_invite = True + + state = await self.store.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) + ) + + event_id = state.get((EventTypes.JoinRules, "")) + if event_id: + join_rules_event = await self.store.get_event( + event_id, allow_none=True + ) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + requires_invite = join_rule and join_rule != JoinRules.PUBLIC + + # Send the invite, if necessary. + if requires_invite: + await room_member_handler.update_membership( + requester=create_requester( + self.hs.config.registration.auto_join_user_id + ), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="invite", + ratelimit=False, + ) + + # Send the join. + await room_member_handler.update_membership( + requester=create_requester(user_id), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ratelimit=False, + ) + except ConsentNotGivenError as e: # Technically not necessary to pull out this error though # moving away from bare excepts is a good thing to do. @@ -322,6 +461,29 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) + async def _auto_join_rooms(self, user_id: str): + """Automatically joins users to auto join rooms - creating the room in the first place + if the user is the first to be created. + + Args: + user_id: The user to join + """ + # auto-join the user to any rooms we're supposed to dump them into + + # try to create the room if we're the first real user on the server. Note + # that an auto-generated support or bot user is not a real user and will never be + # the user to create the room + should_auto_create_rooms = False + is_real_user = await self.store.is_real_user(user_id) + if self.hs.config.registration.autocreate_auto_join_rooms and is_real_user: + count = await self.store.count_real_users() + should_auto_create_rooms = count == 1 + + if should_auto_create_rooms: + await self._create_and_join_rooms(user_id) + else: + await self._join_rooms(user_id) + async def post_consent_actions(self, user_id): """A series of registration actions that can only be carried out once consent has been granted @@ -331,7 +493,10 @@ class RegistrationHandler(BaseHandler): """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart, as_token): + async def appservice_register( + self, user_localpart, as_token, password_hash, display_name + ): + # FIXME: this should be factored out and merged with normal register() user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -348,12 +513,25 @@ class RegistrationHandler(BaseHandler): self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service) + display_name = display_name or user.localpart + await self.register_with_store( user_id=user_id, - password_hash="", + password_hash=password_hash, appservice_id=service_id, - create_profile_with_displayname=user.localpart, + create_profile_with_displayname=display_name, + ) + + await self.profile_handler.set_displayname( + user, None, display_name, by_admin=True ) + + if self.hs.config.user_directory_search_all_users: + profile = await self.store.get_profileinfo(user_localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + return user_id def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): @@ -380,6 +558,37 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) + async def shadow_register(self, localpart, display_name, auth_result, params): + """Invokes the current registration on another server, using + shared secret registration, passing in any auth_results from + other registration UI auth flows (e.g. validated 3pids) + Useful for setting up shadow/backup accounts on a parallel deployment. + """ + + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token), + { + # XXX: auth_result is an unspecified extension for shadow registration + "auth_result": auth_result, + # XXX: another unspecified extension for shadow registration to ensure + # that the displayname is correctly set by the masters erver + "display_name": display_name, + "username": localpart, + "password": params.get("password"), + "bind_msisdn": params.get("bind_msisdn"), + "device_id": params.get("device_id"), + "initial_device_display_name": params.get( + "initial_device_display_name" + ), + "inhibit_login": False, + "access_token": as_token, + }, + ) + async def _generate_user_id(self): if self._next_generated_user_id is None: with await self._generate_user_id_linearizer.queue(()): @@ -392,30 +601,6 @@ class RegistrationHandler(BaseHandler): self._next_generated_user_id += 1 return str(id) - async def _join_user_to_room(self, requester, room_identifier): - room_member_handler = self.hs.get_room_member_handler() - if RoomID.is_valid(room_identifier): - room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = await room_member_handler.lookup_room_alias( - room_alias - ) - room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - - await room_member_handler.update_membership( - requester=requester, - target=requester.user, - room_id=room_id, - remote_room_hosts=remote_room_hosts, - action="join", - ratelimit=False, - ) - def check_registration_ratelimit(self, address): """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -557,6 +742,7 @@ class RegistrationHandler(BaseHandler): if auth_result and LoginType.EMAIL_IDENTITY in auth_result: threepid = auth_result[LoginType.EMAIL_IDENTITY] + # Necessary due to auth checks prior to the threepid being # written to the db if is_threepid_reserved( @@ -564,7 +750,32 @@ class RegistrationHandler(BaseHandler): ): await self.store.upsert_monthly_active_user(user_id) - await self._register_email_threepid(user_id, threepid, access_token) + await self.register_email_threepid(user_id, threepid, access_token) + + if self.hs.config.bind_new_user_emails_to_sydent: + # Attempt to call Sydent's internal bind API on the given identity server + # to bind this threepid + id_server_url = self.hs.config.bind_new_user_emails_to_sydent + + logger.debug( + "Attempting the bind email of %s to identity server: %s using " + "internal Sydent bind API.", + user_id, + self.hs.config.bind_new_user_emails_to_sydent, + ) + + try: + await self.identity_handler.bind_email_using_internal_sydent_api( + id_server_url, threepid["address"], user_id + ) + except Exception as e: + logger.warning( + "Failed to bind email of '%s' to Sydent instance '%s' ", + "using Sydent internal bind API: %s", + user_id, + id_server_url, + e, + ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] @@ -585,7 +796,7 @@ class RegistrationHandler(BaseHandler): await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid(self, user_id, threepid, token): + async def register_email_threepid(self, user_id, threepid, token): """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 61db3ccc43..ac9a58e766 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -24,9 +24,12 @@ import string from collections import OrderedDict from typing import Tuple -from six import iteritems, string_types - -from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset +from synapse.api.constants import ( + EventTypes, + JoinRules, + RoomCreationPreset, + RoomEncryptionAlgorithms, +) from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events.utils import copy_power_levels_contents @@ -56,31 +59,6 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000 class RoomCreationHandler(BaseHandler): - - PRESETS_DICT = { - RoomCreationPreset.PRIVATE_CHAT: { - "join_rules": JoinRules.INVITE, - "history_visibility": "shared", - "original_invitees_have_ops": False, - "guest_can_join": True, - "power_level_content_override": {"invite": 0}, - }, - RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { - "join_rules": JoinRules.INVITE, - "history_visibility": "shared", - "original_invitees_have_ops": True, - "guest_can_join": True, - "power_level_content_override": {"invite": 0}, - }, - RoomCreationPreset.PUBLIC_CHAT: { - "join_rules": JoinRules.PUBLIC, - "history_visibility": "shared", - "original_invitees_have_ops": False, - "guest_can_join": False, - "power_level_content_override": {}, - }, - } - def __init__(self, hs): super(RoomCreationHandler, self).__init__(hs) @@ -89,6 +67,39 @@ class RoomCreationHandler(BaseHandler): self.room_member_handler = hs.get_room_member_handler() self.config = hs.config + # Room state based off defined presets + self._presets_dict = { + RoomCreationPreset.PRIVATE_CHAT: { + "join_rules": JoinRules.INVITE, + "history_visibility": "shared", + "original_invitees_have_ops": False, + "guest_can_join": True, + "power_level_content_override": {"invite": 0}, + }, + RoomCreationPreset.TRUSTED_PRIVATE_CHAT: { + "join_rules": JoinRules.INVITE, + "history_visibility": "shared", + "original_invitees_have_ops": True, + "guest_can_join": True, + "power_level_content_override": {"invite": 0}, + }, + RoomCreationPreset.PUBLIC_CHAT: { + "join_rules": JoinRules.PUBLIC, + "history_visibility": "shared", + "original_invitees_have_ops": False, + "guest_can_join": False, + "power_level_content_override": {}, + }, + } + + # Modify presets to selectively enable encryption by default per homeserver config + for preset_name, preset_config in self._presets_dict.items(): + encrypted = ( + preset_name + in self.config.encryption_enabled_by_default_for_room_presets + ) + preset_config["encrypted"] = encrypted + self._replication = hs.get_replication_data_handler() # linearizer to stop two upgrades happening at once @@ -324,7 +335,19 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not self.spam_checker.user_may_create_room(user_id): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to create rooms + is_requester_admin = True + + else: + is_requester_admin = await self.auth.is_server_admin(requester.user) + + if not is_requester_admin and not self.spam_checker.user_may_create_room( + user_id, invite_list=[], third_party_invite_list=[], cloning=True + ): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -364,7 +387,7 @@ class RoomCreationHandler(BaseHandler): # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) - for k, old_event_id in iteritems(old_room_state_ids): + for k, old_event_id in old_room_state_ids.items(): old_event = old_room_state_events.get(old_event_id) if old_event: initial_state[k] = old_event.content @@ -417,7 +440,7 @@ class RoomCreationHandler(BaseHandler): old_room_member_state_events = await self.store.get_events( old_room_member_state_ids.values() ) - for k, old_event in iteritems(old_room_member_state_events): + for k, old_event in old_room_member_state_events.items(): # Only transfer ban events if ( "membership" in old_event.content @@ -570,8 +593,14 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) + invite_list = config.get("invite", []) + invite_3pid_list = config.get("invite_3pid", []) + if not is_requester_admin and not self.spam_checker.user_may_create_room( - user_id + user_id, + invite_list=invite_list, + third_party_invite_list=invite_3pid_list, + cloning=False, ): raise SynapseError(403, "You are not permitted to create rooms") @@ -582,7 +611,7 @@ class RoomCreationHandler(BaseHandler): "room_version", self.config.default_room_version.identifier ) - if not isinstance(room_version_id, string_types): + if not isinstance(room_version_id, str): raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON) room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) @@ -606,7 +635,6 @@ class RoomCreationHandler(BaseHandler): else: room_alias = None - invite_list = config.get("invite", []) for i in invite_list: try: uid = UserID.from_string(i) @@ -628,8 +656,6 @@ class RoomCreationHandler(BaseHandler): % (user_id,), ) - invite_3pid_list = config.get("invite_3pid", []) - visibility = config.get("visibility", None) is_public = visibility == "public" @@ -731,6 +757,7 @@ class RoomCreationHandler(BaseHandler): "invite", ratelimit=False, content=content, + new_room=True, ) for invite_3pid in invite_3pid_list: @@ -746,6 +773,7 @@ class RoomCreationHandler(BaseHandler): id_server, requester, txn_id=None, + new_room=True, id_access_token=id_access_token, ) @@ -798,7 +826,7 @@ class RoomCreationHandler(BaseHandler): ) return last_stream_id - config = RoomCreationHandler.PRESETS_DICT[preset_config] + config = self._presets_dict[preset_config] creator_id = creator.user.to_string() @@ -815,6 +843,7 @@ class RoomCreationHandler(BaseHandler): "join", ratelimit=False, content=creator_join_profile, + new_room=True, ) # We treat the power levels override specially as this needs to be one @@ -888,6 +917,13 @@ class RoomCreationHandler(BaseHandler): etype=etype, state_key=state_key, content=content ) + if config["encrypted"]: + last_sent_stream_id = await send( + etype=EventTypes.RoomEncryption, + state_key="", + content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, + ) + return last_sent_stream_id async def _generate_room_id( diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4cbc02b0d0..5e05be6181 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -17,8 +17,6 @@ import logging from collections import namedtuple from typing import Any, Dict, Optional -from six import iteritems - import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -271,7 +269,7 @@ class RoomListHandler(BaseHandler): event_map = yield self.store.get_events( [ event_id - for key, event_id in iteritems(current_state_ids) + for key, event_id in current_state_ids.items() if key[0] in ( EventTypes.Create, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0f7af982f0..d585114f2d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -17,10 +17,9 @@ import abc import logging +from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Tuple -from six.moves import http_client - from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError @@ -59,6 +58,7 @@ class RoomMemberHandler(object): self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.identity_handler = hs.get_handlers().identity_handler self.member_linearizer = Linearizer(name="member") @@ -288,8 +288,10 @@ class RoomMemberHandler(object): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, ) -> Tuple[Optional[str], int]: + """Update a user's membership in a room""" key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -303,6 +305,7 @@ class RoomMemberHandler(object): third_party_signed=third_party_signed, ratelimit=ratelimit, content=content, + new_room=new_room, require_consent=require_consent, ) @@ -319,6 +322,7 @@ class RoomMemberHandler(object): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, ) -> Tuple[Optional[str], int]: content_specified = bool(content) @@ -361,7 +365,7 @@ class RoomMemberHandler(object): if effective_membership_state == Membership.INVITE: # block any attempts to invite the server notices mxid if target.to_string() == self._server_notices_mxid: - raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user") + raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -383,8 +387,15 @@ class RoomMemberHandler(object): ) block_invite = True + is_published = await self.store.is_room_published(room_id) + if not self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), + target.to_string(), + third_party_invite=None, + room_id=room_id, + new_room=new_room, + published_room=is_published, ): logger.info("Blocking invite due to spam checker") block_invite = True @@ -444,7 +455,7 @@ class RoomMemberHandler(object): is_blocked = await self._is_server_notice_room(room_id) if is_blocked: raise SynapseError( - http_client.FORBIDDEN, + HTTPStatus.FORBIDDEN, "You cannot reject this invite", errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM, ) @@ -462,6 +473,25 @@ class RoomMemberHandler(object): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to join rooms + is_requester_admin = True + + else: + is_requester_admin = await self.auth.is_server_admin(requester.user) + + inviter = await self._get_inviter(target.to_string(), room_id) + if not is_requester_admin: + # We assume that if the spam checker allowed the user to create + # a room then they're allowed to join it. + if not new_room and not self.spam_checker.user_may_join_room( + target.to_string(), room_id, is_invited=inviter is not None + ): + raise SynapseError(403, "Not allowed to join this room") + if not is_host_in_room: inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): @@ -736,6 +766,7 @@ class RoomMemberHandler(object): id_server: str, requester: Requester, txn_id: Optional[str], + new_room: bool = False, id_access_token: Optional[str] = None, ) -> int: if self.config.block_non_admin_invites: @@ -759,6 +790,16 @@ class RoomMemberHandler(object): Codes.FORBIDDEN, ) + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( + medium, address, room_id + ) + if not can_invite: + raise SynapseError( + 403, + "This third-party identifier can not be invited in this room", + Codes.FORBIDDEN, + ) + if not self._enable_lookup: raise SynapseError( 403, "Looking up third-party identifiers is denied from this server" @@ -768,6 +809,19 @@ class RoomMemberHandler(object): id_server, medium, address, id_access_token ) + is_published = await self.store.is_room_published(room_id) + + if not self.spam_checker.user_may_invite( + requester.user.to_string(), + invitee, + third_party_invite={"medium": medium, "address": address}, + room_id=room_id, + new_room=new_room, + published_room=is_published, + ): + logger.info("Blocking invite due to spam checker") + raise SynapseError(403, "Invites have been disabled on this server") + if invitee: _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 4d245b618b..5d34989f21 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6bdb24baff..4c7524493e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -18,8 +18,6 @@ import itertools import logging from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple -from six import iteritems, itervalues - import attr from prometheus_client import Counter @@ -390,7 +388,7 @@ class SyncHandler(object): # result returned by the event source is poor form (it might cache # the object) room_id = event["room_id"] - event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} + event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) receipt_key = since_token.receipt_key if since_token else "0" @@ -408,7 +406,7 @@ class SyncHandler(object): for event in receipts: room_id = event["room_id"] # exclude room id, as above - event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} + event_copy = {k: v for (k, v) in event.items() if k != "room_id"} ephemeral_by_room.setdefault(room_id, []).append(event_copy) return now_token, ephemeral_by_room @@ -454,7 +452,7 @@ class SyncHandler(object): current_state_ids_map = await self.state.get_current_state_ids( room_id ) - current_state_ids = frozenset(itervalues(current_state_ids_map)) + current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( self.storage, @@ -509,7 +507,7 @@ class SyncHandler(object): current_state_ids_map = await self.state.get_current_state_ids( room_id ) - current_state_ids = frozenset(itervalues(current_state_ids_map)) + current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( self.storage, @@ -909,7 +907,7 @@ class SyncHandler(object): logger.debug("filtering state from %r...", state_ids) state_ids = { t: event_id - for t, event_id in iteritems(state_ids) + for t, event_id in state_ids.items() if cache.get(t[1]) != event_id } logger.debug("...to %r", state_ids) @@ -1430,7 +1428,7 @@ class SyncHandler(object): if since_token: for joined_sync in sync_result_builder.joined: it = itertools.chain( - joined_sync.timeline.events, itervalues(joined_sync.state) + joined_sync.timeline.events, joined_sync.state.values() ) for event in it: if event.type == EventTypes.Member: @@ -1505,7 +1503,7 @@ class SyncHandler(object): newly_left_rooms = [] room_entries = [] invited = [] - for room_id, events in iteritems(mem_change_events_by_room_id): + for room_id, events in mem_change_events_by_room_id.items(): logger.debug( "Membership changes in %s: [%s]", room_id, @@ -1993,17 +1991,17 @@ def _calculate_state( event_id_to_key = { e: key for key, e in itertools.chain( - iteritems(timeline_contains), - iteritems(previous), - iteritems(timeline_start), - iteritems(current), + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(itervalues(current)) - ts_ids = set(itervalues(timeline_start)) - p_ids = set(itervalues(previous)) - tc_ids = set(itervalues(timeline_contains)) + c_ids = set(current.values()) + ts_ids = set(timeline_start.values()) + p_ids = set(previous.values()) + tc_ids = set(timeline_contains.values()) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, @@ -2017,7 +2015,7 @@ def _calculate_state( if lazy_load_members: p_ids.difference_update( - e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member + e for t, e in timeline_start.items() if t[0] == EventTypes.Member ) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c7bc14c623..6c7abaa578 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -15,9 +15,7 @@ import logging from collections import namedtuple -from typing import List - -from twisted.internet import defer +from typing import List, Tuple from synapse.api.errors import AuthError, SynapseError from synapse.logging.context import run_in_background @@ -115,8 +113,7 @@ class TypingHandler(object): def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) - @defer.inlineCallbacks - def started_typing(self, target_user, auth_user, room_id, timeout): + async def started_typing(self, target_user, auth_user, room_id, timeout): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -126,7 +123,7 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -145,8 +142,7 @@ class TypingHandler(object): self._push_update(member=member, typing=True) - @defer.inlineCallbacks - def stopped_typing(self, target_user, auth_user, room_id): + async def stopped_typing(self, target_user, auth_user, room_id): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -156,7 +152,7 @@ class TypingHandler(object): if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, target_user_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id) @@ -164,12 +160,11 @@ class TypingHandler(object): self._stopped_typing(member) - @defer.inlineCallbacks def user_left_room(self, user, room_id): user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) - yield self._stopped_typing(member) + self._stopped_typing(member) def _stopped_typing(self, member): if member.user_id not in self._room_typing.get(member.room_id, set()): @@ -188,10 +183,9 @@ class TypingHandler(object): self._push_update_local(member=member, typing=typing) - @defer.inlineCallbacks - def _push_remote(self, member, typing): + async def _push_remote(self, member, typing): try: - users = yield self.state.get_current_users_in_room(member.room_id) + users = await self.state.get_current_users_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() now = self.clock.time_msec() @@ -215,8 +209,7 @@ class TypingHandler(object): except Exception: logger.exception("Error pushing typing notif to remotes") - @defer.inlineCallbacks - def _recv_edu(self, origin, content): + async def _recv_edu(self, origin, content): room_id = content["room_id"] user_id = content["user_id"] @@ -231,7 +224,7 @@ class TypingHandler(object): ) return - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: @@ -259,14 +252,31 @@ class TypingHandler(object): ) async def get_all_typing_updates( - self, last_id: int, current_id: int, limit: int - ) -> List[dict]: - """Get up to `limit` typing updates between the given tokens, earliest - updates first. + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for typing replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return [] + return [], current_id, False changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id @@ -280,9 +290,16 @@ class TypingHandler(object): serial = self._room_serials[room_id] if last_id < serial <= current_id: typing = self._room_typing[room_id] - rows.append((serial, room_id, list(typing))) + rows.append((serial, [room_id, list(typing)])) rows.sort() - return rows[:limit] + + limited = False + if len(rows) > limit: + rows = rows[:limit] + current_id = rows[-1][0] + limited = True + + return rows, current_id, limited def get_current_token(self): return self._latest_room_serial @@ -306,7 +323,7 @@ class TypingNotificationEventSource(object): "content": {"user_ids": list(typing)}, } - def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -320,7 +337,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return defer.succeed((events, handler._latest_room_serial)) + return (events, handler._latest_room_serial) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 12423b909a..521b6d620d 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -15,8 +15,6 @@ import logging -from six import iteritems, iterkeys - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler @@ -289,7 +287,7 @@ class UserDirectoryHandler(StateDeltasHandler): users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. - for user_id in iterkeys(users_with_profile): + for user_id in users_with_profile.keys(): await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. @@ -298,7 +296,7 @@ class UserDirectoryHandler(StateDeltasHandler): # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. - for user_id, profile in iteritems(users_with_profile): + for user_id, profile in users_with_profile.items(): await self._handle_new_user(room_id, user_id, profile) async def _handle_new_user(self, room_id, user_id, profile):