diff options
124 files changed, 2614 insertions, 1328 deletions
diff --git a/INSTALL.md b/INSTALL.md index d405d9fe55..b9e3f613d1 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -151,29 +151,15 @@ sudo pacman -S base-devel python python-pip \ ##### CentOS/Fedora -Installing prerequisites on CentOS 8 or Fedora>26: +Installing prerequisites on CentOS or Fedora Linux: ```sh sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ - libwebp-devel tk-devel redhat-rpm-config \ - python3-virtualenv libffi-devel openssl-devel + libwebp-devel libxml2-devel libxslt-devel libpq-devel \ + python3-virtualenv libffi-devel openssl-devel python3-devel sudo dnf groupinstall "Development Tools" ``` -Installing prerequisites on CentOS 7 or Fedora<=25: - -```sh -sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ - lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \ - python3-virtualenv libffi-devel openssl-devel -sudo yum groupinstall "Development Tools" -``` - -Note that Synapse does not support versions of SQLite before 3.11, and CentOS 7 -uses SQLite 3.7. You may be able to work around this by installing a more -recent SQLite version, but it is recommended that you instead use a Postgres -database: see [docs/postgres.md](docs/postgres.md). - ##### macOS Installing prerequisites on macOS: diff --git a/changelog.d/9003.misc b/changelog.d/9003.misc new file mode 100644 index 0000000000..557c8b2353 --- /dev/null +++ b/changelog.d/9003.misc @@ -0,0 +1 @@ +Fix 'object name reserved for internal use' errors with recent versions of SQLite. diff --git a/changelog.d/9123.misc b/changelog.d/9123.misc new file mode 100644 index 0000000000..329600c40c --- /dev/null +++ b/changelog.d/9123.misc @@ -0,0 +1 @@ +Add experimental support for running Synapse with PyPy. diff --git a/changelog.d/9150.feature b/changelog.d/9150.feature new file mode 100644 index 0000000000..48a8148dee --- /dev/null +++ b/changelog.d/9150.feature @@ -0,0 +1 @@ +New API /_synapse/admin/rooms/{roomId}/context/{eventId}. diff --git a/changelog.d/9240.misc b/changelog.d/9240.misc new file mode 100644 index 0000000000..850201f6cd --- /dev/null +++ b/changelog.d/9240.misc @@ -0,0 +1 @@ +Deny access to additional IP addresses by default. diff --git a/changelog.d/9257.bugfix b/changelog.d/9257.bugfix new file mode 100644 index 0000000000..5d0bd88dce --- /dev/null +++ b/changelog.d/9257.bugfix @@ -0,0 +1 @@ +Fix long-standing bug where sending email push would fail for rooms that the server had since left. diff --git a/changelog.d/9291.doc b/changelog.d/9291.doc new file mode 100644 index 0000000000..422acd3891 --- /dev/null +++ b/changelog.d/9291.doc @@ -0,0 +1 @@ +Add note to `auto_join_rooms` config option explaining existing rooms must be publicly joinable. diff --git a/changelog.d/9296.bugfix b/changelog.d/9296.bugfix new file mode 100644 index 0000000000..d723f8c5bd --- /dev/null +++ b/changelog.d/9296.bugfix @@ -0,0 +1 @@ +Fix bug in Synapse 1.27.0rc1 which meant the "session expired" error page during SSO registration was badly formatted. diff --git a/changelog.d/9299.misc b/changelog.d/9299.misc new file mode 100644 index 0000000000..c883a677ed --- /dev/null +++ b/changelog.d/9299.misc @@ -0,0 +1 @@ +Update the `Cursor` type hints to better match PEP 249. diff --git a/changelog.d/9300.feature b/changelog.d/9300.feature new file mode 100644 index 0000000000..a2d0b27da4 --- /dev/null +++ b/changelog.d/9300.feature @@ -0,0 +1 @@ +Further improvements to the user experience of registration via single sign-on. diff --git a/changelog.d/9301.feature b/changelog.d/9301.feature new file mode 100644 index 0000000000..a2d0b27da4 --- /dev/null +++ b/changelog.d/9301.feature @@ -0,0 +1 @@ +Further improvements to the user experience of registration via single sign-on. diff --git a/changelog.d/9305.misc b/changelog.d/9305.misc new file mode 100644 index 0000000000..456bfbfdd7 --- /dev/null +++ b/changelog.d/9305.misc @@ -0,0 +1 @@ +Add debug logging for SRV lookups. Contributed by @Bubu. diff --git a/changelog.d/9307.misc b/changelog.d/9307.misc new file mode 100644 index 0000000000..2f54d1ad07 --- /dev/null +++ b/changelog.d/9307.misc @@ -0,0 +1 @@ +Improve logging for OIDC login flow. diff --git a/changelog.d/9308.doc b/changelog.d/9308.doc new file mode 100644 index 0000000000..847f2908af --- /dev/null +++ b/changelog.d/9308.doc @@ -0,0 +1 @@ +Correct name of Synapse's service file in TURN howto. diff --git a/changelog.d/9311.feature b/changelog.d/9311.feature new file mode 100644 index 0000000000..293f2118e5 --- /dev/null +++ b/changelog.d/9311.feature @@ -0,0 +1 @@ +Add hook to spam checker modules that allow checking file uploads and remote downloads. diff --git a/changelog.d/9317.doc b/changelog.d/9317.doc new file mode 100644 index 0000000000..f4d508e090 --- /dev/null +++ b/changelog.d/9317.doc @@ -0,0 +1 @@ +Fix the braces in the `oidc_providers` section of the sample config. diff --git a/changelog.d/9321.bugfix b/changelog.d/9321.bugfix new file mode 100644 index 0000000000..52eed80969 --- /dev/null +++ b/changelog.d/9321.bugfix @@ -0,0 +1 @@ +Assert a maximum length for the `client_secret` parameter for spec compliance. diff --git a/changelog.d/9322.doc b/changelog.d/9322.doc new file mode 100644 index 0000000000..c393a3a299 --- /dev/null +++ b/changelog.d/9322.doc @@ -0,0 +1 @@ +Update installation instructions on Fedora. diff --git a/changelog.d/9326.misc b/changelog.d/9326.misc new file mode 100644 index 0000000000..768c18d27e --- /dev/null +++ b/changelog.d/9326.misc @@ -0,0 +1 @@ +Share the code for handling required attributes between the CAS and SAML handlers. diff --git a/changelog.d/9333.bugfix b/changelog.d/9333.bugfix new file mode 100644 index 0000000000..c34ba378c5 --- /dev/null +++ b/changelog.d/9333.bugfix @@ -0,0 +1 @@ +Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.". diff --git a/changelog.d/9361.bugfix b/changelog.d/9361.bugfix new file mode 100644 index 0000000000..4d0477f033 --- /dev/null +++ b/changelog.d/9361.bugfix @@ -0,0 +1 @@ +Fix a bug causing Synapse to impose the wrong type constraints on fields when processing responses from appservices to `/_matrix/app/v1/thirdparty/user/{protocol}`. diff --git a/changelog.d/9377.misc b/changelog.d/9377.misc new file mode 100644 index 0000000000..df1348ec42 --- /dev/null +++ b/changelog.d/9377.misc @@ -0,0 +1 @@ +Convert tests to use `HomeserverTestCase`. diff --git a/changelog.d/9391.bugfix b/changelog.d/9391.bugfix new file mode 100644 index 0000000000..b5e68e2ac7 --- /dev/null +++ b/changelog.d/9391.bugfix @@ -0,0 +1 @@ +Fix bug where Synapse would occaisonally stop reconnecting after the connection was lost. diff --git a/changelog.d/9394.misc b/changelog.d/9394.misc new file mode 100644 index 0000000000..b3e90143cc --- /dev/null +++ b/changelog.d/9394.misc @@ -0,0 +1 @@ +Remove some dead code from the acceptance of room invites path. \ No newline at end of file diff --git a/changelog.d/9395.bugfix b/changelog.d/9395.bugfix new file mode 100644 index 0000000000..d45cc4ffb9 --- /dev/null +++ b/changelog.d/9395.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when upgrading a room: "TypeError: '>' not supported between instances of 'NoneType' and 'int'". diff --git a/changelog.d/9396.misc b/changelog.d/9396.misc new file mode 100644 index 0000000000..df1348ec42 --- /dev/null +++ b/changelog.d/9396.misc @@ -0,0 +1 @@ +Convert tests to use `HomeserverTestCase`. diff --git a/changelog.d/9407.doc b/changelog.d/9407.doc new file mode 100644 index 0000000000..36979bc0d8 --- /dev/null +++ b/changelog.d/9407.doc @@ -0,0 +1 @@ +Document that pusher instances are shardable. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 3832b36407..bc737b30f5 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -10,6 +10,7 @@ * [Undoing room shutdowns](#undoing-room-shutdowns) - [Make Room Admin API](#make-room-admin-api) - [Forward Extremities Admin API](#forward-extremities-admin-api) +- [Event Context API](#event-context-api) # List Room API @@ -594,3 +595,121 @@ that were deleted. "deleted": 1 } ``` + +# Event Context API + +This API lets a client find the context of an event. This is designed primarily to investigate abuse reports. + +``` +GET /_synapse/admin/v1/rooms/<room_id>/context/<event_id> +``` + +This API mimmicks [GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-context-eventid). Please refer to the link for all details on parameters and reseponse. + +Example response: + +```json +{ + "end": "t29-57_2_0_2", + "events_after": [ + { + "content": { + "body": "This is an example text message", + "msgtype": "m.text", + "format": "org.matrix.custom.html", + "formatted_body": "<b>This is an example text message</b>" + }, + "type": "m.room.message", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + } + } + ], + "event": { + "content": { + "body": "filename.jpg", + "info": { + "h": 398, + "w": 394, + "mimetype": "image/jpeg", + "size": 31037 + }, + "url": "mxc://example.org/JWEIFJgwEIhweiWJE", + "msgtype": "m.image" + }, + "type": "m.room.message", + "event_id": "$f3h4d129462ha:example.com", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + } + }, + "events_before": [ + { + "content": { + "body": "something-important.doc", + "filename": "something-important.doc", + "info": { + "mimetype": "application/msword", + "size": 46144 + }, + "msgtype": "m.file", + "url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe" + }, + "type": "m.room.message", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + } + } + ], + "start": "t27-54_2_0_2", + "state": [ + { + "content": { + "creator": "@example:example.org", + "room_version": "1", + "m.federate": true, + "predecessor": { + "event_id": "$something:example.org", + "room_id": "!oldroom:example.org" + } + }, + "type": "m.room.create", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + }, + "state_key": "" + }, + { + "content": { + "membership": "join", + "avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF", + "displayname": "Alice Margatroid" + }, + "type": "m.room.member", + "event_id": "$143273582443PhrSn:example.org", + "room_id": "!636q39766251:example.com", + "sender": "@example:example.org", + "origin_server_ts": 1432735824653, + "unsigned": { + "age": 1234 + }, + "state_key": "@alice:example.org" + } + ] +} +``` diff --git a/docs/openid.md b/docs/openid.md index 9d19368845..0be79591b1 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -388,3 +388,25 @@ oidc_providers: localpart_template: "{{ user.login }}" display_name_template: "{{ user.full_name }}" ``` + +### XWiki + +Install [OpenID Connect Provider](https://extensions.xwiki.org/xwiki/bin/view/Extension/OpenID%20Connect/OpenID%20Connect%20Provider/) extension in your [XWiki](https://www.xwiki.org) instance. + +Synapse config: + +```yaml +oidc_providers: + - idp_id: xwiki + idp_name: "XWiki" + issuer: "https://myxwikihost/xwiki/oidc/" + client_id: "your-client-id" # TO BE FILLED + # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed + client_secret: "dontcare" + scopes: ["openid", "profile"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" +``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index d395da11b4..13a6f045f9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -165,6 +165,7 @@ pid_file: DATADIR/homeserver.pid # - '100.64.0.0/10' # - '192.0.0.0/24' # - '169.254.0.0/16' +# - '192.88.99.0/24' # - '198.18.0.0/15' # - '192.0.2.0/24' # - '198.51.100.0/24' @@ -173,6 +174,9 @@ pid_file: DATADIR/homeserver.pid # - '::1/128' # - 'fe80::/10' # - 'fc00::/7' +# - '2001:db8::/32' +# - 'ff00::/8' +# - 'fec0::/10' # List of IP address CIDR ranges that should be allowed for federation, # identity servers, push servers, and for checking key validity for @@ -990,6 +994,7 @@ media_store_path: "DATADIR/media_store" # - '100.64.0.0/10' # - '192.0.0.0/24' # - '169.254.0.0/16' +# - '192.88.99.0/24' # - '198.18.0.0/15' # - '192.0.2.0/24' # - '198.51.100.0/24' @@ -998,6 +1003,9 @@ media_store_path: "DATADIR/media_store" # - '::1/128' # - 'fe80::/10' # - 'fc00::/7' +# - '2001:db8::/32' +# - 'ff00::/8' +# - 'fec0::/10' # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. @@ -1318,6 +1326,8 @@ account_threepid_delegates: # By default, any room aliases included in this list will be created # as a publicly joinable room when the first user registers for the # homeserver. This behaviour can be customised with the settings below. +# If the room already exists, make certain it is a publicly joinable +# room. The join rule of the room must be set to 'public'. # #auto_join_rooms: # - "#example:example.com" @@ -1860,9 +1870,9 @@ oidc_providers: # user_mapping_provider: # config: # subject_claim: "id" - # localpart_template: "{ user.login }" - # display_name_template: "{ user.name }" - # email_template: "{ user.email }" + # localpart_template: "{{ user.login }}" + # display_name_template: "{{ user.name }}" + # email_template: "{{ user.email }}" # For use with Keycloak # @@ -1889,8 +1899,8 @@ oidc_providers: # user_mapping_provider: # config: # subject_claim: "id" - # localpart_template: "{ user.login }" - # display_name_template: "{ user.name }" + # localpart_template: "{{ user.login }}" + # display_name_template: "{{ user.name }}" # Enable Central Authentication Service (CAS) for registration and login. @@ -2222,7 +2232,7 @@ ui_auth: # session to be active. # # This defaults to 0, meaning the user is queried for their credentials - # before every action, but this can be overridden to alow a single + # before every action, but this can be overridden to allow a single # validation to be re-used. This weakens the protections afforded by # the user-interactive authentication process, by allowing for multiple # (and potentially different) operations to use the same validation session. diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 5b4f6428e6..47a27bf85c 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -61,6 +61,9 @@ class ExampleSpamChecker: async def check_registration_for_spam(self, email_threepid, username, request_info): return RegistrationBehaviour.ALLOW # allow all registrations + + async def check_media_file_for_spam(self, file_wrapper, file_info): + return False # allow all media ``` ## Configuration diff --git a/docs/turn-howto.md b/docs/turn-howto.md index e8f13ad484..41738bbe69 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -187,7 +187,7 @@ After updating the homeserver configuration, you must restart synapse: ``` * If you use systemd: ``` - systemctl restart synapse.service + systemctl restart matrix-synapse.service ``` ... and then reload any clients (or wait an hour for them to refresh their settings). diff --git a/docs/workers.md b/docs/workers.md index f7fc6df119..9bda0f8c23 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -373,7 +373,15 @@ Handles sending push notifications to sygnal and email. Doesn't handle any REST endpoints itself, but you should set `start_pushers: False` in the shared configuration file to stop the main synapse sending push notifications. -Note this worker cannot be load-balanced: only one instance should be active. +To run multiple instances at once the `pusher_instances` option should list all +pusher instances by their worker name, e.g.: + +```yaml +pusher_instances: + - pusher_worker1 + - pusher_worker2 +``` + ### `synapse.app.appservice` diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh index 60e8970a35..b8d1e636f1 100755 --- a/scripts-dev/make_full_schema.sh +++ b/scripts-dev/make_full_schema.sh @@ -162,12 +162,23 @@ else fi # Delete schema_version, applied_schema_deltas and applied_module_schemas tables +# Also delete any shadow tables from fts4 # This needs to be done after synapse_port_db is run echo "Dropping unwanted db tables..." SQL=" DROP TABLE schema_version; DROP TABLE applied_schema_deltas; DROP TABLE applied_module_schemas; +DROP TABLE event_search_content; +DROP TABLE event_search_segments; +DROP TABLE event_search_segdir; +DROP TABLE event_search_docsize; +DROP TABLE event_search_stat; +DROP TABLE user_directory_search_content; +DROP TABLE user_directory_search_segments; +DROP TABLE user_directory_search_segdir; +DROP TABLE user_directory_search_docsize; +DROP TABLE user_directory_search_stat; " sqlite3 "$SQLITE_DB" <<< "$SQL" psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL" diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e366a982b8..11aee50f7a 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -76,9 +76,6 @@ def _is_valid_3pe_result(r, field): fields = r["fields"] if not isinstance(fields, dict): return False - for k in fields.keys(): - if not isinstance(fields[k], str): - return False return True diff --git a/synapse/config/auth.py b/synapse/config/auth.py index 2b3e2ce87b..1f4c090cde 100644 --- a/synapse/config/auth.py +++ b/synapse/config/auth.py @@ -98,7 +98,7 @@ class AuthConfig(Config): # session to be active. # # This defaults to 0, meaning the user is queried for their credentials - # before every action, but this can be overridden to alow a single + # before every action, but this can be overridden to allow a single # validation to be re-used. This weakens the protections afforded by # the user-interactive authentication process, by allowing for multiple # (and potentially different) operations to use the same validation session. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index aaa7eba110..dbf5085965 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, List + +from synapse.config.sso import SsoAttributeRequirement + from ._base import Config, ConfigError +from ._util import validate_config class CasConfig(Config): @@ -40,12 +45,16 @@ class CasConfig(Config): # TODO Update this to a _synapse URL. self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" self.cas_displayname_attribute = cas_config.get("displayname_attribute") - self.cas_required_attributes = cas_config.get("required_attributes") or {} + required_attributes = cas_config.get("required_attributes") or {} + self.cas_required_attributes = _parsed_required_attributes_def( + required_attributes + ) + else: self.cas_server_url = None self.cas_service_url = None self.cas_displayname_attribute = None - self.cas_required_attributes = {} + self.cas_required_attributes = [] def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ @@ -77,3 +86,22 @@ class CasConfig(Config): # userGroup: "staff" # department: None """ + + +# CAS uses a legacy required attributes mapping, not the one provided by +# SsoAttributeRequirement. +REQUIRED_ATTRIBUTES_SCHEMA = { + "type": "object", + "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]}, +} + + +def _parsed_required_attributes_def( + required_attributes: Any, +) -> List[SsoAttributeRequirement]: + validate_config( + REQUIRED_ATTRIBUTES_SCHEMA, + required_attributes, + config_path=("cas_config", "required_attributes"), + ) + return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()] diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4d0f24a9d5..d081f36fa5 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -201,9 +201,9 @@ class OIDCConfig(Config): # user_mapping_provider: # config: # subject_claim: "id" - # localpart_template: "{{ user.login }}" - # display_name_template: "{{ user.name }}" - # email_template: "{{ user.email }}" + # localpart_template: "{{{{ user.login }}}}" + # display_name_template: "{{{{ user.name }}}}" + # email_template: "{{{{ user.email }}}}" # For use with Keycloak # @@ -230,8 +230,8 @@ class OIDCConfig(Config): # user_mapping_provider: # config: # subject_claim: "id" - # localpart_template: "{{ user.login }}" - # display_name_template: "{{ user.name }}" + # localpart_template: "{{{{ user.login }}}}" + # display_name_template: "{{{{ user.name }}}}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) diff --git a/synapse/config/registration.py b/synapse/config/registration.py index eb650af7fb..ead007ba5a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -391,6 +391,8 @@ class RegistrationConfig(Config): # By default, any room aliases included in this list will be created # as a publicly joinable room when the first user registers for the # homeserver. This behaviour can be customised with the settings below. + # If the room already exists, make certain it is a publicly joinable + # room. The join rule of the room must be set to 'public'. # #auto_join_rooms: # - "#example:example.com" diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 850ac3ebd6..fcaea8fb93 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -17,9 +17,7 @@ import os from collections import namedtuple from typing import Dict, List -from netaddr import IPSet - -from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST +from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module @@ -187,16 +185,17 @@ class ContentRepositoryConfig(Config): "to work" ) - self.url_preview_ip_range_blacklist = IPSet( - config["url_preview_ip_range_blacklist"] - ) - # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. - self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.url_preview_ip_range_blacklist = generate_ip_set( + config["url_preview_ip_range_blacklist"], + ["0.0.0.0", "::"], + config_path=("url_preview_ip_range_blacklist",), + ) - self.url_preview_ip_range_whitelist = IPSet( - config.get("url_preview_ip_range_whitelist", ()) + self.url_preview_ip_range_whitelist = generate_ip_set( + config.get("url_preview_ip_range_whitelist", ()), + config_path=("url_preview_ip_range_whitelist",), ) self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 9a3e1c3e7d..2dd719c388 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -123,7 +123,7 @@ class RoomDirectoryConfig(Config): alias (str) Returns: - boolean: True if user is allowed to crate the alias + boolean: True if user is allowed to create the alias """ for rule in self._alias_creation_rules: if rule.matches(user_id, room_id, [alias]): diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 7226abd829..4b494f217f 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -17,8 +17,7 @@ import logging from typing import Any, List -import attr - +from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -398,32 +397,18 @@ class SAML2Config(Config): } -@attr.s(frozen=True) -class SamlAttributeRequirement: - """Object describing a single requirement for SAML attributes.""" - - attribute = attr.ib(type=str) - value = attr.ib(type=str) - - JSON_SCHEMA = { - "type": "object", - "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, - "required": ["attribute", "value"], - } - - ATTRIBUTE_REQUIREMENTS_SCHEMA = { "type": "array", - "items": SamlAttributeRequirement.JSON_SCHEMA, + "items": SsoAttributeRequirement.JSON_SCHEMA, } def _parse_attribute_requirements_def( attribute_requirements: Any, -) -> List[SamlAttributeRequirement]: +) -> List[SsoAttributeRequirement]: validate_config( ATTRIBUTE_REQUIREMENTS_SCHEMA, attribute_requirements, - config_path=["saml2_config", "attribute_requirements"], + config_path=("saml2_config", "attribute_requirements"), ) - return [SamlAttributeRequirement(**x) for x in attribute_requirements] + return [SsoAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/config/server.py b/synapse/config/server.py index 5d72cf2d82..a635b8a7dc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging import os.path import re @@ -23,7 +24,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml -from netaddr import IPSet +from netaddr import AddrFormatError, IPNetwork, IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.util.stringutils import parse_and_validate_server_name @@ -40,6 +41,66 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] + +def _6to4(network: IPNetwork) -> IPNetwork: + """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056.""" + + # 6to4 networks consist of: + # * 2002 as the first 16 bits + # * The first IPv4 address in the network hex-encoded as the next 32 bits + # * The new prefix length needs to include the bits from the 2002 prefix. + hex_network = hex(network.first)[2:] + hex_network = ("0" * (8 - len(hex_network))) + hex_network + return IPNetwork( + "2002:%s:%s::/%d" % (hex_network[:4], hex_network[4:], 16 + network.prefixlen,) + ) + + +def generate_ip_set( + ip_addresses: Optional[Iterable[str]], + extra_addresses: Optional[Iterable[str]] = None, + config_path: Optional[Iterable[str]] = None, +) -> IPSet: + """ + Generate an IPSet from a list of IP addresses or CIDRs. + + Additionally, for each IPv4 network in the list of IP addresses, also + includes the corresponding IPv6 networks. + + This includes: + + * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1) + * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2) + * 6to4 Address (see RFC 3056, section 2) + + Args: + ip_addresses: An iterable of IP addresses or CIDRs. + extra_addresses: An iterable of IP addresses or CIDRs. + config_path: The path in the configuration for error messages. + + Returns: + A new IP set. + """ + result = IPSet() + for ip in itertools.chain(ip_addresses or (), extra_addresses or ()): + try: + network = IPNetwork(ip) + except AddrFormatError as e: + raise ConfigError( + "Invalid IP range provided: %s." % (ip,), config_path + ) from e + result.add(network) + + # It is possible that these already exist in the set, but that's OK. + if ":" not in str(network): + result.add(IPNetwork(network).ipv6(ipv4_compatible=True)) + result.add(IPNetwork(network).ipv6(ipv4_compatible=False)) + result.add(_6to4(network)) + + return result + + +# IP ranges that are considered private / unroutable / don't make sense. DEFAULT_IP_RANGE_BLACKLIST = [ # Localhost "127.0.0.0/8", @@ -53,6 +114,8 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "192.0.0.0/24", # Link-local networks. "169.254.0.0/16", + # Formerly used for 6to4 relay. + "192.88.99.0/24", # Testing networks. "198.18.0.0/15", "192.0.2.0/24", @@ -66,6 +129,12 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "fe80::/10", # Unique local addresses. "fc00::/7", + # Testing networks. + "2001:db8::/32", + # Multicast. + "ff00::/8", + # Site-local addresses + "fec0::/10", ] DEFAULT_ROOM_VERSION = "6" @@ -290,17 +359,15 @@ class ServerConfig(Config): ) # Attempt to create an IPSet from the given ranges - try: - self.ip_range_blacklist = IPSet(ip_range_blacklist) - except Exception as e: - raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e + # Always blacklist 0.0.0.0, :: - self.ip_range_blacklist.update(["0.0.0.0", "::"]) + self.ip_range_blacklist = generate_ip_set( + ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",) + ) - try: - self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ())) - except Exception as e: - raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e + self.ip_range_whitelist = generate_ip_set( + config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",) + ) # The federation_ip_range_blacklist is used for backwards-compatibility # and only applies to federation and identity servers. If it is not given, @@ -308,14 +375,12 @@ class ServerConfig(Config): federation_ip_range_blacklist = config.get( "federation_ip_range_blacklist", ip_range_blacklist ) - try: - self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist." - ) from e # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.federation_ip_range_blacklist = generate_ip_set( + federation_ip_range_blacklist, + ["0.0.0.0", "::"], + config_path=("federation_ip_range_blacklist",), + ) if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 19bdfd462b..07ba217f89 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Optional + +import attr from ._base import Config +@attr.s(frozen=True) +class SsoAttributeRequirement: + """Object describing a single requirement for SSO attributes.""" + + attribute = attr.ib(type=str) + # If a value is not given, than the attribute must simply exist. + value = attr.ib(type=Optional[str]) + + JSON_SCHEMA = { + "type": "object", + "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, + "required": ["attribute", "value"], + } + + class SSOConfig(Config): """SSO Configuration """ diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index e7e3a7b9a4..8cfc0bb3cb 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -17,6 +17,8 @@ import inspect from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from synapse.rest.media.v1._base import FileInfo +from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection from synapse.util.async_helpers import maybe_awaitable @@ -214,3 +216,48 @@ class SpamChecker: return behaviour return RegistrationBehaviour.ALLOW + + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> bool: + """Checks if a piece of newly uploaded media should be blocked. + + This will be called for local uploads, downloads of remote media, each + thumbnail generated for those, and web pages/images used for URL + previews. + + Note that care should be taken to not do blocking IO operations in the + main thread. For example, to get the contents of a file a module + should do:: + + async def check_media_file_for_spam( + self, file: ReadableFileWrapper, file_info: FileInfo + ) -> bool: + buffer = BytesIO() + await file.write_chunks_to(buffer.write) + + if buffer.getvalue() == b"Hello World": + return True + + return False + + + Args: + file: An object that allows reading the contents of the media. + file_info: Metadata about the file. + + Returns: + True if the media should be blocked or False if it should be + allowed. + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_media_file_for_spam", None) + if checker: + spam = await maybe_awaitable(checker(file_wrapper, file_info)) + if spam: + return True + + return False diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 0d042cbfac..76bf52ea23 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -18,6 +18,7 @@ import logging from synapse.api.errors import Codes, SynapseError +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute @@ -32,6 +33,11 @@ logger = logging.getLogger(__name__) # TODO: Flairs +# Note that the maximum lengths are somewhat arbitrary. +MAX_SHORT_DESC_LEN = 1000 +MAX_LONG_DESC_LEN = 10000 + + class GroupsServerWorkerHandler: def __init__(self, hs): self.hs = hs @@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler): ) profile = {} - for keyname in ("name", "avatar_url", "short_description", "long_description"): + for keyname, max_length in ( + ("name", MAX_DISPLAYNAME_LEN), + ("avatar_url", MAX_AVATAR_URL_LEN), + ("short_description", MAX_SHORT_DESC_LEN), + ("long_description", MAX_LONG_DESC_LEN), + ): if keyname in content: value = content[keyname] if not isinstance(value, str): - raise SynapseError(400, "%r value is not a string" % (keyname,)) + raise SynapseError( + 400, + "%r value is not a string" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) + if len(value) > max_length: + raise SynapseError( + 400, + "Invalid %s parameter" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) profile[keyname] = value await self.store.update_group_profile(group_id, profile) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a19c556437..648fe91f53 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1472,10 +1472,22 @@ class AuthHandler(BaseHandler): # Remove the query parameters from the redirect URL to get a shorter version of # it. This is only to display a human-readable URL in the template, but not the # URL we redirect users to. - redirect_url_no_params = client_redirect_url.split("?")[0] + url_parts = urllib.parse.urlsplit(client_redirect_url) + + if url_parts.scheme == "https": + # for an https uri, just show the netloc (ie, the hostname. Specifically, + # the bit between "//" and "/"; this includes any potential + # "username:password@" prefix.) + display_url = url_parts.netloc + else: + # for other uris, strip the query-params (including the login token) and + # fragment. + display_url = urllib.parse.urlunsplit( + (url_parts.scheme, url_parts.netloc, url_parts.path, "", "") + ) html = self._sso_redirect_confirm_template.render( - display_url=redirect_url_no_params, + display_url=display_url, redirect_url=redirect_url, server_name=self._server_name, new_user=new_user, diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index bd35d1fb87..81ed44ac87 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from xml.etree import ElementTree as ET import attr @@ -49,7 +49,7 @@ class CasError(Exception): @attr.s(slots=True, frozen=True) class CasResponse: username = attr.ib(type=str) - attributes = attr.ib(type=Dict[str, Optional[str]]) + attributes = attr.ib(type=Dict[str, List[Optional[str]]]) class CasHandler: @@ -169,7 +169,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} + attributes = {} # type: Dict[str, List[Optional[str]]] for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -182,7 +182,7 @@ class CasHandler: tag = attribute.tag if "}" in tag: tag = tag.split("}")[1] - attributes[tag] = attribute.text + attributes.setdefault(tag, []).append(attribute.text) # Ensure a user was found. if user is None: @@ -303,29 +303,10 @@ class CasHandler: # Ensure that the attributes of the logged in user meet the required # attributes. - for required_attribute, required_value in self._cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in cas_response.attributes: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return - - # Also need to check value - if required_value is not None: - actual_value = cas_response.attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return + if not self._sso_handler.check_required_attributes( + request, cas_response.attributes, self._cas_required_attributes + ): + return # Call the mapper to register/login the user @@ -372,9 +353,10 @@ class CasHandler: if failures: raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + # Arbitrarily use the first attribute found. display_name = cas_response.attributes.get( - self._cas_displayname_attribute, None - ) + self._cas_displayname_attribute, [None] + )[0] return UserAttributes(localpart=localpart, display_name=display_name) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eddc7582d0..5581e06bb4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1354,8 +1354,6 @@ class FederationHandler(BaseHandler): await self._clean_room_for_join(room_id) - handled_events = set() - try: # Try the host we successfully got a response to /make_join/ # request first. @@ -1375,10 +1373,6 @@ class FederationHandler(BaseHandler): auth_chain = ret["auth_chain"] auth_chain.sort(key=lambda e: e.depth) - handled_events.update([s.event_id for s in state]) - handled_events.update([a.event_id for a in auth_chain]) - handled_events.add(event.event_id) - logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join state: %s", state) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 71008ec50d..3adc75fa4a 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -123,7 +123,6 @@ class OidcHandler: Args: request: the incoming request from the browser. """ - # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: @@ -137,8 +136,12 @@ class OidcHandler: # either the provider misbehaving or Synapse being misconfigured. # The only exception of that is "access_denied", where the user # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) + logger.log( + logging.INFO if error == "access_denied" else logging.ERROR, + "Received OIDC callback with error: %s %s", + error, + description, + ) self._sso_handler.render_error(request, error, description) return @@ -149,7 +152,7 @@ class OidcHandler: # Fetch the session cookie session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] if session is None: - logger.info("No session cookie found") + logger.info("Received OIDC callback, with no session cookie") self._sso_handler.render_error( request, "missing_session", "No session cookie found" ) @@ -169,7 +172,7 @@ class OidcHandler: # Check for the state query parameter if b"state" not in request.args: - logger.info("State parameter is missing") + logger.info("Received OIDC callback, with no state parameter") self._sso_handler.render_error( request, "invalid_request", "State parameter is missing" ) @@ -183,14 +186,16 @@ class OidcHandler: session, state ) except (MacaroonDeserializationException, ValueError) as e: - logger.exception("Invalid session") + logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") + logger.exception("Could not verify session for OIDC callback") self._sso_handler.render_error(request, "mismatching_session", str(e)) return + logger.info("Received OIDC callback for IdP %s", session_data.idp_id) + oidc_provider = self._providers.get(session_data.idp_id) if not oidc_provider: logger.error("OIDC session uses unknown IdP %r", oidc_provider) @@ -565,6 +570,7 @@ class OidcProvider: Returns: UserInfo: an object representing the user. """ + logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() resp = await self._http_client.get_json( @@ -572,6 +578,8 @@ class OidcProvider: headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, ) + logger.debug("Retrieved user info from userinfo endpoint: %r", resp) + return UserInfo(resp) async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: @@ -600,17 +608,19 @@ class OidcProvider: claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) claim_options = {"iss": {"values": [metadata["issuer"]]}} + id_token = token["id_token"] + logger.debug("Attempting to decode JWT id_token %r", id_token) + # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() try: claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, @@ -620,13 +630,15 @@ class OidcProvider: logger.info("Reloading JWKS after decode error") jwk_set = await self.load_jwks(force=True) # try reloading the jwks claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, claims_params=claims_params, ) + logger.debug("Decoded id_token JWT %r; validating", claims) + claims.validate(leeway=120) # allows 2 min of clock skew return UserInfo(claims) @@ -726,19 +738,18 @@ class OidcProvider: """ # Exchange the code with the provider try: - logger.debug("Exchanging code") + logger.debug("Exchanging OAuth2 code for a token") token = await self._exchange_code(code) except OidcError as e: - logger.exception("Could not exchange code") + logger.exception("Could not exchange OAuth2 code") self._sso_handler.render_error(request, e.error, e.error_description) return - logger.debug("Successfully obtained OAuth2 access token") + logger.debug("Successfully obtained OAuth2 token data: %r", token) # Now that we have a token, get the userinfo, either by decoding the # `id_token` or by fetching the `userinfo_endpoint`. if self._uses_userinfo: - logger.debug("Fetching userinfo") try: userinfo = await self._fetch_userinfo(token) except Exception as e: @@ -746,7 +757,6 @@ class OidcProvider: self._sso_handler.render_error(request, "fetch_error", str(e)) return else: - logger.debug("Extracting userinfo from id_token") try: userinfo = await self._parse_id_token(token, nonce=session_data.nonce) except Exception as e: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 07b2187eb1..591a82f459 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,6 +38,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -424,17 +425,20 @@ class RoomCreationHandler(BaseHandler): # Copy over user power levels now as this will not be possible with >100PL users once # the room has been created - # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) - state_default = power_levels.get("state_default", 0) - ban = power_levels.get("ban") + state_default = power_levels.get("state_default", 50) + ban = power_levels.get("ban", 50) needed_power_level = max(state_default, ban, max(event_power_levels.values())) + # Get the user's current power level, this matches the logic in get_user_power_level, + # but without the entire state map. + user_power_levels = power_levels.setdefault("users", {}) + users_default = power_levels.get("users_default", 0) + current_power_level = user_power_levels.get(user_id, users_default) # Raise the requester's power level in the new room if necessary - current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - power_levels["users"][user_id] = needed_power_level + user_power_levels[user_id] = needed_power_level await self._send_events_for_new_room( requester, @@ -828,7 +832,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - # Always wait for room creation to progate before returning + # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), "events", @@ -1004,41 +1008,51 @@ class RoomCreationHandler(BaseHandler): class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.auth = hs.get_auth() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state async def get_event_context( self, - user: UserID, + requester: Requester, room_id: str, event_id: str, limit: int, event_filter: Optional[Filter], + use_admin_priviledge: bool = False, ) -> Optional[JsonDict]: """Retrieves events, pagination tokens and state around a given event in a room. Args: - user + requester room_id event_id limit: The maximum number of events to return in total (excluding state). event_filter: the filter to apply to the events returned (excluding the target event_id) - + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. Returns: dict, or None if the event isn't found """ + user = requester.user + if use_admin_priviledge: + await assert_user_is_admin(self.auth, requester.user) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users - def filter_evts(events): - return filter_events_for_client( + async def filter_evts(events): + if use_admin_priviledge: + return events + return await filter_events_for_client( self.storage, user.to_string(), events, is_peeking=is_peeking ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index e88fd59749..78f130e152 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -23,7 +23,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.config.saml2_config import SamlAttributeRequirement from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -239,12 +238,10 @@ class SamlHandler(BaseHandler): # Ensure that the attributes of the logged in user meet the required # attributes. - for requirement in self._saml2_attribute_requirements: - if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._sso_handler.render_error( - request, "unauthorised", "You are not authorised to log in here." - ) - return + if not self._sso_handler.check_required_attributes( + request, saml2_auth.ava, self._saml2_attribute_requirements + ): + return # Call the mapper to register/login the user try: @@ -373,21 +370,6 @@ class SamlHandler(BaseHandler): del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: - values = ava.get(req.attribute, []) - for v in values: - if v == req.value: - return True - - logger.info( - "SAML2 attribute %s did not match required value '%s' (was '%s')", - req.attribute, - req.value, - values, - ) - return False - - DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index b450668f1c..a63fd52485 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -16,10 +16,12 @@ import abc import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, Iterable, + List, Mapping, Optional, Set, @@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -742,7 +745,11 @@ class SsoHandler: use_display_name: whether the user wants to use the suggested display name emails_to_use: emails that the user would like to use """ - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return # update the session with the user's choices session.chosen_localpart = localpart @@ -793,7 +800,12 @@ class SsoHandler: session_id, terms_version, ) - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + session.terms_accepted_version = terms_version # we're done; now we can register the user @@ -808,7 +820,11 @@ class SsoHandler: request: HTTP request session_id: ID of the username mapping session, extracted from a cookie """ - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return logger.info( "[session %s] Registering localpart %s", @@ -880,6 +896,41 @@ class SsoHandler: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return False + + return True + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie @@ -890,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: if not session_id: raise SynapseError(code=400, msg="missing session_id") return session_id.decode("ascii", errors="replace") + + +def _check_attribute_requirement( + attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement +) -> bool: + """Check if SSO attributes meet the proper requirements. + + Args: + attributes: A mapping of attributes to an iterable of one or more values. + requirement: The configured requirement to check. + + Returns: + True if the required attribute was found and had a proper value. + """ + if req.attribute not in attributes: + logger.info("SSO attribute missing: %s", req.attribute) + return False + + # If the requirement is None, the attribute existing is enough. + if req.value is None: + return True + + values = attributes[req.attribute] + if req.value in values: + return True + + logger.info( + "SSO attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + return False diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 4c06a117d3..113fd47134 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -323,12 +323,19 @@ class MatrixHostnameEndpoint: if port or _is_ip_literal(host): return [Server(host, port or 8448)] + logger.debug("Looking up SRV record for %s", host.decode(errors="replace")) server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host) if server_list: + logger.debug( + "Got %s from SRV lookup for %s", + ", ".join(map(str, server_list)), + host.decode(errors="replace"), + ) return server_list # No SRV records, so we fallback to host and 8448 + logger.debug("No SRV records for %s", host.decode(errors="replace")) return [Server(host, 8448)] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9018f9e20b..6317f22d3c 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -524,7 +524,7 @@ class RulesForRoom: class _Invalidation: # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules, # which means that it it is stored on the bulk_get_push_rules cache entry. In order - # to ensure that we don't accumulate lots of redunant callbacks on the cache entry, + # to ensure that we don't accumulate lots of redundant callbacks on the cache entry, # we need to ensure that two _Invalidation objects are "equal" if they refer to the # same `cache` and `room_id`. # diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 8a6dcff30d..d10201b6b3 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -34,6 +34,7 @@ from synapse.push.presentable_names import ( descriptor_from_member_events, name_from_member_event, ) +from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client @@ -110,6 +111,7 @@ class Mailer: self.sendmail = self.hs.get_sendmail() self.store = self.hs.get_datastore() + self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() @@ -217,7 +219,17 @@ class Mailer: push_actions: Iterable[Dict[str, Any]], reason: Dict[str, Any], ) -> None: - """Send email regarding a user's room notifications""" + """ + Send email regarding a user's room notifications + + Params: + app_id: The application receiving the notification. + user_id: The user receiving the notification. + email_address: The email address receiving the notification. + push_actions: All outstanding notifications. + reason: The notification that was ready and is the cause of an email + being sent. + """ rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) notif_events = await self.store.get_events( @@ -241,7 +253,7 @@ class Mailer: except StoreError: user_display_name = user_id - async def _fetch_room_state(room_id): + async def _fetch_room_state(room_id: str) -> None: room_state = await self.store.get_current_state_ids(room_id) state_by_room[room_id] = room_state @@ -255,7 +267,7 @@ class Mailer: rooms = [] for r in rooms_in_order: - roomvars = await self.get_room_vars( + roomvars = await self._get_room_vars( r, user_id, notifs_by_room[r], notif_events, state_by_room[r] ) rooms.append(roomvars) @@ -271,7 +283,7 @@ class Mailer: # Only one room has new stuff room_id = list(notifs_by_room.keys())[0] - summary_text = await self.make_summary_text_single_room( + summary_text = await self._make_summary_text_single_room( room_id, notifs_by_room[room_id], state_by_room[room_id], @@ -279,13 +291,13 @@ class Mailer: user_id, ) else: - summary_text = await self.make_summary_text( + summary_text = await self._make_summary_text( notifs_by_room, state_by_room, notif_events, reason ) template_vars = { "user_display_name": user_display_name, - "unsubscribe_link": self.make_unsubscribe_link( + "unsubscribe_link": self._make_unsubscribe_link( user_id, app_id, email_address ), "summary_text": summary_text, @@ -349,7 +361,7 @@ class Mailer: ) ) - async def get_room_vars( + async def _get_room_vars( self, room_id: str, user_id: str, @@ -357,6 +369,20 @@ class Mailer: notif_events: Dict[str, EventBase], room_state_ids: StateMap[str], ) -> Dict[str, Any]: + """ + Generate the variables for notifications on a per-room basis. + + Args: + room_id: The room ID + user_id: The user receiving the notification. + notifs: The outstanding push actions for this room. + notif_events: The events related to the above notifications. + room_state_ids: The event IDs of the current room state. + + Returns: + A dictionary to be added to the template context. + """ + # Check if one of the notifs is an invite event for the user. is_invite = False for n in notifs: @@ -373,12 +399,12 @@ class Mailer: "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], "invite": is_invite, - "link": self.make_room_link(room_id), + "link": self._make_room_link(room_id), } # type: Dict[str, Any] if not is_invite: for n in notifs: - notifvars = await self.get_notif_vars( + notifvars = await self._get_notif_vars( n, user_id, notif_events[n["event_id"]], room_state_ids ) @@ -405,13 +431,26 @@ class Mailer: return room_vars - async def get_notif_vars( + async def _get_notif_vars( self, notif: Dict[str, Any], user_id: str, notif_event: EventBase, room_state_ids: StateMap[str], ) -> Dict[str, Any]: + """ + Generate the variables for a single notification. + + Args: + notif: The outstanding notification for this room. + user_id: The user receiving the notification. + notif_event: The event related to the above notification. + room_state_ids: The event IDs of the current room state. + + Returns: + A dictionary to be added to the template context. + """ + results = await self.store.get_events_around( notif["room_id"], notif["event_id"], @@ -420,7 +459,7 @@ class Mailer: ) ret = { - "link": self.make_notif_link(notif), + "link": self._make_notif_link(notif), "ts": notif["received_ts"], "messages": [], } @@ -431,22 +470,51 @@ class Mailer: the_events.append(notif_event) for event in the_events: - messagevars = await self.get_message_vars(notif, event, room_state_ids) + messagevars = await self._get_message_vars(notif, event, room_state_ids) if messagevars is not None: ret["messages"].append(messagevars) return ret - async def get_message_vars( + async def _get_message_vars( self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str] ) -> Optional[Dict[str, Any]]: + """ + Generate the variables for a single event, if possible. + + Args: + notif: The outstanding notification for this room. + event: The event under consideration. + room_state_ids: The event IDs of the current room state. + + Returns: + A dictionary to be added to the template context, or None if the + event cannot be processed. + """ if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: return None - sender_state_event_id = room_state_ids[("m.room.member", event.sender)] - sender_state_event = await self.store.get_event(sender_state_event_id) - sender_name = name_from_member_event(sender_state_event) - sender_avatar_url = sender_state_event.content.get("avatar_url") + # Get the sender's name and avatar from the room state. + type_state_key = ("m.room.member", event.sender) + sender_state_event_id = room_state_ids.get(type_state_key) + if sender_state_event_id: + sender_state_event = await self.store.get_event( + sender_state_event_id + ) # type: Optional[EventBase] + else: + # Attempt to check the historical state for the room. + historical_state = await self.state_store.get_state_for_event( + event.event_id, StateFilter.from_types((type_state_key,)) + ) + sender_state_event = historical_state.get(type_state_key) + + if sender_state_event: + sender_name = name_from_member_event(sender_state_event) + sender_avatar_url = sender_state_event.content.get("avatar_url") + else: + # No state could be found, fallback to the MXID. + sender_name = event.sender + sender_avatar_url = None # 'hash' for deterministically picking default images: use # sender_hash % the number of default images to choose from @@ -471,18 +539,25 @@ class Mailer: ret["msgtype"] = msgtype if msgtype == "m.text": - self.add_text_message_vars(ret, event) + self._add_text_message_vars(ret, event) elif msgtype == "m.image": - self.add_image_message_vars(ret, event) + self._add_image_message_vars(ret, event) if "body" in event.content: ret["body_text_plain"] = event.content["body"] return ret - def add_text_message_vars( + def _add_text_message_vars( self, messagevars: Dict[str, Any], event: EventBase ) -> None: + """ + Potentially add a sanitised message body to the message variables. + + Args: + messagevars: The template context to be modified. + event: The event under consideration. + """ msgformat = event.content.get("format") messagevars["format"] = msgformat @@ -495,16 +570,20 @@ class Mailer: elif body: messagevars["body_text_html"] = safe_text(body) - def add_image_message_vars( + def _add_image_message_vars( self, messagevars: Dict[str, Any], event: EventBase ) -> None: """ Potentially add an image URL to the message variables. + + Args: + messagevars: The template context to be modified. + event: The event under consideration. """ if "url" in event.content: messagevars["image_url"] = event.content["url"] - async def make_summary_text_single_room( + async def _make_summary_text_single_room( self, room_id: str, notifs: List[Dict[str, Any]], @@ -517,7 +596,7 @@ class Mailer: Args: room_id: The ID of the room. - notifs: The notifications for this room. + notifs: The push actions for this room. room_state_ids: The state map for the room. notif_events: A map of event ID -> notification event. user_id: The user receiving the notification. @@ -600,11 +679,11 @@ class Mailer: "app": self.app_name, } - return await self.make_summary_text_from_member_events( + return await self._make_summary_text_from_member_events( room_id, notifs, room_state_ids, notif_events ) - async def make_summary_text( + async def _make_summary_text( self, notifs_by_room: Dict[str, List[Dict[str, Any]]], room_state_ids: Dict[str, StateMap[str]], @@ -615,7 +694,7 @@ class Mailer: Make a summary text for the email when multiple rooms have notifications. Args: - notifs_by_room: A map of room ID to the notifications for that room. + notifs_by_room: A map of room ID to the push actions for that room. room_state_ids: A map of room ID to the state map for that room. notif_events: A map of event ID -> notification event. reason: The reason this notification is being sent. @@ -632,11 +711,11 @@ class Mailer: } room_id = reason["room_id"] - return await self.make_summary_text_from_member_events( + return await self._make_summary_text_from_member_events( room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events ) - async def make_summary_text_from_member_events( + async def _make_summary_text_from_member_events( self, room_id: str, notifs: List[Dict[str, Any]], @@ -648,7 +727,7 @@ class Mailer: Args: room_id: The ID of the room. - notifs: The notifications for this room. + notifs: The push actions for this room. room_state_ids: The state map for the room. notif_events: A map of event ID -> notification event. @@ -657,14 +736,45 @@ class Mailer: """ # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = {notif_events[n["event_id"]].sender for n in notifs} - member_events = await self.store.get_events( - [room_state_ids[("m.room.member", s)] for s in sender_ids] - ) + # Find the latest event ID for each sender, note that the notifications + # are already in descending received_ts. + sender_ids = {} + for n in notifs: + sender = notif_events[n["event_id"]].sender + if sender not in sender_ids: + sender_ids[sender] = n["event_id"] + + # Get the actual member events (in order to calculate a pretty name for + # the room). + member_event_ids = [] + member_events = {} + for sender_id, event_id in sender_ids.items(): + type_state_key = ("m.room.member", sender_id) + sender_state_event_id = room_state_ids.get(type_state_key) + if sender_state_event_id: + member_event_ids.append(sender_state_event_id) + else: + # Attempt to check the historical state for the room. + historical_state = await self.state_store.get_state_for_event( + event_id, StateFilter.from_types((type_state_key,)) + ) + sender_state_event = historical_state.get(type_state_key) + if sender_state_event: + member_events[event_id] = sender_state_event + member_events.update(await self.store.get_events(member_event_ids)) + + if not member_events: + # No member events were found! Maybe the room is empty? + # Fallback to the room ID (note that if there was a room name this + # would already have been used previously). + return self.email_subjects.messages_in_room % { + "room": room_id, + "app": self.app_name, + } # There was a single sender. - if len(sender_ids) == 1: + if len(member_events) == 1: return self.email_subjects.messages_from_person % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, @@ -676,7 +786,16 @@ class Mailer: "app": self.app_name, } - def make_room_link(self, room_id: str) -> str: + def _make_room_link(self, room_id: str) -> str: + """ + Generate a link to open a room in the web client. + + Args: + room_id: The room ID to generate a link to. + + Returns: + A link to open a room in the web client. + """ if self.hs.config.email_riot_base_url: base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) elif self.app_name == "Vector": @@ -686,7 +805,16 @@ class Mailer: base_url = "https://matrix.to/#" return "%s/%s" % (base_url, room_id) - def make_notif_link(self, notif: Dict[str, str]) -> str: + def _make_notif_link(self, notif: Dict[str, str]) -> str: + """ + Generate a link to open an event in the web client. + + Args: + notif: The notification to generate a link for. + + Returns: + A link to open the notification in the web client. + """ if self.hs.config.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email_riot_base_url, @@ -702,9 +830,20 @@ class Mailer: else: return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) - def make_unsubscribe_link( + def _make_unsubscribe_link( self, user_id: str, app_id: str, email_address: str ) -> str: + """ + Generate a link to unsubscribe from email notifications. + + Args: + user_id: The user receiving the notification. + app_id: The application receiving the notification. + email_address: The email address receiving the notification. + + Returns: + A link to unsubscribe from email notifications. + """ params = { "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index bfd46a3730..8a2b73b75e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -86,8 +86,12 @@ REQUIREMENTS = [ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], - # we use execute_values with the fetch param, which arrived in psycopg 2.8. - "postgres": ["psycopg2>=2.8"], + "postgres": [ + # we use execute_values with the fetch param, which arrived in psycopg 2.8. + "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'", + "psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'", + "psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'", + ], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index fdd087683b..89f8af0f36 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -15,8 +15,9 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast +import attr import txredisapi from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -42,6 +43,24 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +T = TypeVar("T") +V = TypeVar("V") + + +@attr.s +class ConstantProperty(Generic[T, V]): + """A descriptor that returns the given constant, ignoring attempts to set + it. + """ + + constant = attr.ib() # type: V + + def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V: + return self.constant + + def __set__(self, obj: Optional[T], value: V): + pass + class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. @@ -195,6 +214,10 @@ class SynapseRedisFactory(txredisapi.RedisFactory): we detect dead connections. """ + # We want to *always* retry connecting, txredisapi will stop if there is a + # failure during certain operations, e.g. during AUTH. + continueTrying = cast(bool, ConstantProperty(True)) + def __init__( self, hs: "HomeServer", @@ -243,7 +266,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """ maxDelay = 5 - continueTrying = True protocol = RedisSubscriber def __init__( diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css index 46b309ea4e..338214f5d0 100644 --- a/synapse/res/templates/sso.css +++ b/synapse/res/templates/sso.css @@ -1,16 +1,26 @@ -body { +body, input, select, textarea { font-family: "Inter", "Helvetica", "Arial", sans-serif; font-size: 14px; color: #17191C; } -header { +header, footer { max-width: 480px; width: 100%; margin: 24px auto; text-align: center; } +@media screen and (min-width: 800px) { + header { + margin-top: 90px; + } +} + +header { + min-height: 60px; +} + header p { color: #737D8C; line-height: 24px; @@ -20,6 +30,10 @@ h1 { font-size: 24px; } +a { + color: #418DED; +} + .error_page h1 { color: #FE2928; } @@ -47,6 +61,9 @@ main { .primary-button { border: none; + -webkit-appearance: none; + -moz-appearance: none; + appearance: none; text-decoration: none; padding: 12px; color: white; @@ -63,8 +80,17 @@ main { .profile { display: flex; + flex-direction: column; + align-items: center; justify-content: center; - margin: 24px 0; + margin: 24px; + padding: 13px; + border: 1px solid #E9ECF1; + border-radius: 4px; +} + +.profile.with-avatar { + margin-top: 42px; /* (36px / 2) + 24px*/ } .profile .avatar { @@ -72,17 +98,32 @@ main { height: 36px; border-radius: 100%; display: block; - margin-right: 8px; + margin-top: -32px; + margin-bottom: 8px; } .profile .display-name { font-weight: bold; margin-bottom: 4px; + font-size: 15px; + line-height: 18px; } .profile .user-id { color: #737D8C; + font-size: 12px; + line-height: 12px; } -.profile .display-name, .profile .user-id { - line-height: 18px; +footer { + margin-top: 80px; } + +footer svg { + display: block; + width: 46px; + margin: 0px auto 12px auto; +} + +footer p { + color: #737D8C; +} \ No newline at end of file diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html index 50a0979c2f..c3e4deed93 100644 --- a/synapse/res/templates/sso_account_deactivated.html +++ b/synapse/res/templates/sso_account_deactivated.html @@ -20,5 +20,6 @@ administrator. </p> </header> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index 36850a2d6a..f4fdc40b22 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -1,12 +1,29 @@ <!DOCTYPE html> <html lang="en"> <head> - <title>Synapse Login</title> + <title>Create your account</title> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, user-scalable=no"> + <script type="text/javascript"> + let wasKeyboard = false; + document.addEventListener("mousedown", function() { wasKeyboard = false; }); + document.addEventListener("keydown", function() { wasKeyboard = true; }); + document.addEventListener("focusin", function() { + if (wasKeyboard) { + document.body.classList.add("keyboard-focus"); + } else { + document.body.classList.remove("keyboard-focus"); + } + }); + </script> <style type="text/css"> {% include "sso.css" without context %} + body.keyboard-focus :focus, body.keyboard-focus .username_input:focus-within { + outline: 3px solid #17191C; + outline-offset: 4px; + } + .username_input { display: flex; border: 2px solid #418DED; @@ -33,11 +50,12 @@ .username_input label { position: absolute; - top: -8px; + top: -5px; left: 14px; - font-size: 80%; + font-size: 10px; + line-height: 10px; background: white; - padding: 2px; + padding: 0 2px; } .username_input input { @@ -47,6 +65,13 @@ border: none; } + /* only clear the outline if we know it will be shown on the parent div using :focus-within */ + @supports selector(:focus-within) { + .username_input input { + outline: none !important; + } + } + .username_input div { color: #8D99A5; } @@ -65,6 +90,7 @@ .idp-pick-details .idp-detail { border-top: 1px solid #E9ECF1; padding: 12px; + display: block; } .idp-pick-details .check-row { display: flex; @@ -117,43 +143,44 @@ </div> <output for="username_input" id="field-username-output"></output> <input type="submit" value="Continue" class="primary-button"> - {% if user_attributes %} + {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %} <section class="idp-pick-details"> <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2> {% if user_attributes.avatar_url %} - <div class="idp-detail idp-avatar"> + <label class="idp-detail idp-avatar" for="idp-avatar"> <div class="check-row"> - <label for="idp-avatar" class="name">Avatar</label> - <label for="idp-avatar" class="use">Use</label> + <span class="name">Avatar</span> + <span class="use">Use</span> <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked> </div> <img src="{{ user_attributes.avatar_url }}" class="avatar" /> - </div> + </label> {% endif %} {% if user_attributes.display_name %} - <div class="idp-detail"> + <label class="idp-detail" for="idp-displayname"> <div class="check-row"> - <label for="idp-displayname" class="name">Display name</label> - <label for="idp-displayname" class="use">Use</label> + <span class="name">Display name</span> + <span class="use">Use</span> <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked> </div> <p class="idp-value">{{ user_attributes.display_name }}</p> - </div> + </label> {% endif %} {% for email in user_attributes.emails %} - <div class="idp-detail"> + <label class="idp-detail" for="idp-email{{ loop.index }}"> <div class="check-row"> - <label for="idp-email{{ loop.index }}" class="name">E-mail</label> - <label for="idp-email{{ loop.index }}" class="use">Use</label> + <span class="name">E-mail</span> + <span class="use">Use</span> <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked> </div> <p class="idp-value">{{ email }}</p> - </div> + </label> {% endfor %} </section> {% endif %} </form> </main> + {% include "sso_footer.html" without context %} <script type="text/javascript"> {% include "sso_auth_account_details.js" without context %} </script> diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html index c9bd4bef20..da579ffe69 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -21,5 +21,6 @@ the Identity Provider as when you log into your account. </p> </header> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index 2099c2f1f8..f9d0456f0a 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -2,7 +2,7 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>Authentication</title> + <title>Confirm it's you</title> <meta name="viewport" content="width=device-width, user-scalable=no"> <style type="text/css"> {% include "sso.css" without context %} @@ -24,5 +24,6 @@ Continue with {{ idp.idp_name }} </a> </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html index 3b975d7219..1ed3967e87 100644 --- a/synapse/res/templates/sso_auth_success.html +++ b/synapse/res/templates/sso_auth_success.html @@ -23,5 +23,6 @@ application. </p> </header> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index b223ca0f56..472309c350 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -38,6 +38,7 @@ <p>{{ error }}</p> </div> </header> + {% include "sso_footer.html" without context %} <script type="text/javascript"> // Error handling to support Auth0 errors that we might get through a GET request diff --git a/synapse/res/templates/sso_footer.html b/synapse/res/templates/sso_footer.html new file mode 100644 index 0000000000..588a3d508d --- /dev/null +++ b/synapse/res/templates/sso_footer.html @@ -0,0 +1,19 @@ +<footer> + <svg role="img" aria-label="[Matrix logo]" viewBox="0 0 200 85" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"> + <g id="parent" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd"> + <g id="child" transform="translate(-122.000000, -6.000000)" fill="#000000" fill-rule="nonzero"> + <g id="matrix-logo" transform="translate(122.000000, 6.000000)"> + <polygon id="left-bracket" points="2.24708861 1.93811009 2.24708861 82.7268844 8.10278481 82.7268844 8.10278481 84.6652459 0 84.6652459 0 0 8.10278481 0 8.10278481 1.93811009"></polygon> + <path d="M24.8073418,27.5493174 L24.8073418,31.6376991 L24.924557,31.6376991 C26.0227848,30.0814294 27.3455696,28.8730642 28.8951899,28.0163743 C30.4437975,27.1611927 32.2189873,26.7318422 34.218481,26.7318422 C36.1394937,26.7318422 37.8946835,27.102622 39.4825316,27.8416679 C41.0708861,28.5819706 42.276962,29.8856073 43.1005063,31.7548404 C44.0017722,30.431345 45.2270886,29.2629486 46.7767089,28.2506569 C48.3253165,27.2388679 50.158481,26.7318422 52.2764557,26.7318422 C53.8843038,26.7318422 55.3736709,26.9269101 56.7473418,27.3162917 C58.1189873,27.7056734 59.295443,28.3285835 60.2759494,29.185022 C61.255443,30.0422147 62.02,31.1615927 62.5701266,32.5426532 C63.1187342,33.9262275 63.3936709,35.5898349 63.3936709,37.5372459 L63.3936709,57.7443688 L55.0410127,57.7441174 L55.0410127,40.6319376 C55.0410127,39.6201486 55.0020253,38.6661761 54.9232911,37.7700202 C54.8440506,36.8751211 54.6293671,36.0968606 54.2764557,35.4339817 C53.9232911,34.772611 53.403038,34.2464807 52.7177215,33.8568477 C52.0313924,33.4689743 51.0997468,33.2731523 49.9235443,33.2731523 C48.7473418,33.2731523 47.7962025,33.4983853 47.0706329,33.944578 C46.344557,34.393033 45.7764557,34.9774826 45.3650633,35.6969211 C44.9534177,36.4181193 44.6787342,37.2353431 44.5417722,38.150855 C44.4037975,39.0653615 44.3356962,39.9904257 44.3356962,40.9247908 L44.3356962,57.7443688 L35.9835443,57.7443688 L35.9835443,40.8079009 C35.9835443,39.9124991 35.963038,39.0263982 35.9253165,38.150855 C35.8853165,37.2743064 35.7192405,36.4666349 35.424557,35.7263321 C35.1303797,34.9872862 34.64,34.393033 33.9539241,33.944578 C33.2675949,33.4983853 32.2579747,33.2731523 30.9248101,33.2731523 C30.5321519,33.2731523 30.0126582,33.3608826 29.3663291,33.5365945 C28.7192405,33.7118037 28.0913924,34.0433688 27.4840506,34.5292789 C26.875443,35.0164459 26.3564557,35.7172826 25.9250633,36.6315376 C25.4934177,37.5470495 25.2779747,38.7436 25.2779747,40.2229486 L25.2779747,57.7441174 L16.9260759,57.7443688 L16.9260759,27.5493174 L24.8073418,27.5493174 Z" id="m"></path> + <path d="M68.7455696,31.9886202 C69.6075949,30.7033339 70.7060759,29.672189 72.0397468,28.8926716 C73.3724051,28.1141596 74.8716456,27.5596239 76.5387342,27.2283101 C78.2050633,26.8977505 79.8817722,26.7315908 81.5678481,26.7315908 C83.0974684,26.7315908 84.6458228,26.8391798 86.2144304,27.0525982 C87.7827848,27.2675248 89.2144304,27.6865688 90.5086076,28.3087248 C91.8025316,28.9313835 92.8610127,29.7983798 93.6848101,30.9074514 C94.5083544,32.0170257 94.92,33.4870734 94.92,35.3173431 L94.92,51.026844 C94.92,52.3913138 94.998481,53.6941963 95.1556962,54.9400165 C95.3113924,56.1865908 95.5863291,57.120956 95.9787342,57.7436147 L87.5091139,57.7436147 C87.3518987,57.276055 87.2240506,56.7996972 87.1265823,56.3125303 C87.0278481,55.8266202 86.9592405,55.3301523 86.9207595,54.8236294 C85.5873418,56.1865908 84.0182278,57.1405633 82.2156962,57.6857982 C80.4113924,58.2295248 78.5683544,58.503022 76.6860759,58.503022 C75.2346835,58.503022 73.8817722,58.3275615 72.6270886,57.9776459 C71.3718987,57.6269761 70.2744304,57.082244 69.3334177,56.3411872 C68.3921519,55.602644 67.656962,54.6680275 67.1275949,53.5390972 C66.5982278,52.410167 66.3331646,51.065556 66.3331646,49.5087835 C66.3331646,47.7961578 66.6367089,46.384178 67.2455696,45.2756092 C67.8529114,44.1652807 68.6367089,43.2799339 69.5987342,42.6173064 C70.5589873,41.9556844 71.6567089,41.4592165 72.8924051,41.1284055 C74.1273418,40.7978459 75.3721519,40.5356606 76.6270886,40.3398385 C77.8820253,40.1457761 79.116962,39.9896716 80.3329114,39.873033 C81.5483544,39.7558917 82.6270886,39.5804312 83.5681013,39.3469028 C84.5093671,39.1133743 85.2536709,38.7732624 85.8032911,38.3250587 C86.3513924,37.8773578 86.6063291,37.2252881 86.5678481,36.3680954 C86.5678481,35.4731963 86.4210127,34.7620532 86.1268354,34.2366771 C85.8329114,33.7113009 85.4405063,33.3018092 84.9506329,33.0099615 C84.4602532,32.7181138 83.8916456,32.5232972 83.2450633,32.4255119 C82.5977215,32.3294862 81.9010127,32.2797138 81.156962,32.2797138 C79.5098734,32.2797138 78.2159494,32.6303835 77.2746835,33.3312202 C76.3339241,34.0320569 75.7837975,35.2007046 75.6275949,36.8354037 L67.275443,36.8354037 C67.3924051,34.8892495 67.8817722,33.2726495 68.7455696,31.9886202 Z M85.2440506,43.6984752 C84.7149367,43.873433 84.1460759,44.0189798 83.5387342,44.1361211 C82.9306329,44.253011 82.2936709,44.350545 81.6270886,44.4279688 C80.96,44.5066495 80.2934177,44.6034294 79.6273418,44.7203193 C78.9994937,44.8362037 78.3820253,44.9933138 77.7749367,45.1871248 C77.1663291,45.3829468 76.636962,45.6451321 76.1865823,45.9759431 C75.7349367,46.3070055 75.3724051,46.7263009 75.0979747,47.2313156 C74.8232911,47.7375872 74.6863291,48.380356 74.6863291,49.1588679 C74.6863291,49.8979138 74.8232911,50.5218294 75.0979747,51.026844 C75.3724051,51.5338697 75.7455696,51.9328037 76.2159494,52.2246514 C76.6863291,52.5164991 77.2349367,52.7213706 77.8632911,52.8375064 C78.4898734,52.9546477 79.136962,53.012967 79.8037975,53.012967 C81.4506329,53.012967 82.724557,52.740978 83.6273418,52.1952404 C84.5288608,51.6507596 85.1949367,50.9981872 85.6270886,50.2382771 C86.0579747,49.4793725 86.323038,48.7119211 86.4212658,47.9321523 C86.518481,47.1536404 86.5681013,46.5304789 86.5681013,46.063422 L86.5681013,42.9677248 C86.2146835,43.2799339 85.7736709,43.5230147 85.2440506,43.6984752 Z" id="a"></path> + <path d="M116.917975,27.5493174 L116.917975,33.0976917 L110.801266,33.0976917 L110.801266,48.0492936 C110.801266,49.4502128 111.036203,50.3850807 111.507089,50.8518862 C111.976962,51.3191945 112.918734,51.5527229 114.33038,51.5527229 C114.801013,51.5527229 115.251392,51.5336183 115.683038,51.4944037 C116.114177,51.4561945 116.526076,51.3968697 116.917975,51.3194459 L116.917975,57.7438661 C116.212152,57.860756 115.427595,57.9381798 114.565316,57.9778972 C113.702785,58.0153523 112.859747,58.0357138 112.036203,58.0357138 C110.742278,58.0357138 109.516456,57.9477321 108.36,57.7722716 C107.202785,57.5975651 106.183544,57.2577046 105.301519,56.7509303 C104.418987,56.2454128 103.722785,55.5242147 103.213418,54.5898495 C102.703038,53.6562385 102.448608,52.4292716 102.448608,50.9099541 L102.448608,33.0976917 L97.3903797,33.0976917 L97.3903797,27.5493174 L102.448608,27.5493174 L102.448608,18.4967596 L110.801013,18.4967596 L110.801013,27.5493174 L116.917975,27.5493174 Z" id="t"></path> + <path d="M128.857975,27.5493174 L128.857975,33.1565138 L128.975696,33.1565138 C129.367089,32.2213945 129.896203,31.3559064 130.563544,30.557033 C131.23038,29.7596679 131.99443,29.0776844 132.857215,28.5130936 C133.719241,27.9495083 134.641266,27.5113596 135.622532,27.1988991 C136.601772,26.8879468 137.622025,26.7315908 138.681013,26.7315908 C139.229873,26.7315908 139.836962,26.8296275 140.504304,27.0239413 L140.504304,34.7336477 C140.111646,34.6552183 139.641013,34.586844 139.092658,34.5290275 C138.543291,34.4704569 138.014177,34.4410459 137.504304,34.4410459 C135.974937,34.4410459 134.681013,34.6949358 133.622785,35.2004532 C132.564051,35.7067248 131.711392,36.397255 131.064051,37.2735523 C130.417215,38.1501009 129.955443,39.1714422 129.681266,40.3398385 C129.407089,41.5074807 129.269873,42.7736624 129.269873,44.1361211 L129.269873,57.7438661 L120.917722,57.7438661 L120.917722,27.5493174 L128.857975,27.5493174 Z" id="r"></path> + <path d="M144.033165,22.8767376 L144.033165,16.0435798 L152.386076,16.0435798 L152.386076,22.8767376 L144.033165,22.8767376 Z M152.386076,27.5493174 L152.386076,57.7438661 L144.033165,57.7438661 L144.033165,27.5493174 L152.386076,27.5493174 Z" id="i"></path> + <polygon id="x" points="156.738228 27.5493174 166.266582 27.5493174 171.619494 35.4337303 176.913418 27.5493174 186.147848 27.5493174 176.148861 41.6831927 187.383544 57.7441174 177.85443 57.7441174 171.501772 48.2245028 165.148861 57.7441174 155.797468 57.7441174 166.737468 41.8589046"></polygon> + <polygon id="right-bracket" points="197.580759 82.7268844 197.580759 1.93811009 191.725063 1.93811009 191.725063 0 199.828354 0 199.828354 84.6652459 191.725063 84.6652459 191.725063 82.7268844"></polygon> + </g> + </g> + </g> + </svg> + <p>An open network for secure, decentralized communication.<br>© 2021 The Matrix.org Foundation C.I.C.</p> +</footer> \ No newline at end of file diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index 62a640dad2..53b82db84e 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -2,30 +2,60 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <link rel="stylesheet" href="/_matrix/static/client/login/style.css"> - <title>{{ server_name }} Login</title> + <title>Choose identity provider</title> + <style type="text/css"> + {% include "sso.css" without context %} + + .providers { + list-style: none; + padding: 0; + } + + .providers li { + margin: 12px; + } + + .providers a { + display: block; + border-radius: 4px; + border: 1px solid #17191C; + padding: 8px; + text-align: center; + text-decoration: none; + color: #17191C; + display: flex; + align-items: center; + font-weight: bold; + } + + .providers a img { + width: 24px; + height: 24px; + } + .providers a span { + flex: 1; + } + </style> </head> <body> - <div id="container"> - <h1 id="title">{{ server_name }} Login</h1> - <div class="login_flow"> - <p>Choose one of the following identity providers:</p> - <form> - <input type="hidden" name="redirectUrl" value="{{ redirect_url }}"> - <ul class="radiobuttons"> -{% for p in providers %} - <li> - <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}"> - <label for="prov{{ loop.index }}">{{ p.idp_name }}</label> -{% if p.idp_icon %} + <header> + <h1>Log in to {{ server_name }} </h1> + <p>Choose an identity provider to log in</p> + </header> + <main> + <ul class="providers"> + {% for p in providers %} + <li> + <a href="pick_idp?idp={{ p.idp_id }}&redirectUrl={{ redirect_url | urlencode }}"> + {% if p.idp_icon %} <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/> -{% endif %} - </li> -{% endfor %} - </ul> - <input type="submit" class="button button--full-width" id="button-submit" value="Submit"> - </form> - </div> - </div> + {% endif %} + <span>{{ p.idp_name }}</span> + </a> + </li> + {% endfor %} + </ul> + </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html index 8c33787c54..68c8b9f33a 100644 --- a/synapse/res/templates/sso_new_user_consent.html +++ b/synapse/res/templates/sso_new_user_consent.html @@ -2,7 +2,7 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>SSO redirect confirmation</title> + <title>Agree to terms and conditions</title> <meta name="viewport" content="width=device-width, user-scalable=no"> <style type="text/css"> {% include "sso.css" without context %} @@ -18,22 +18,15 @@ <p>Agree to the terms to create your account.</p> </header> <main> - <!-- {% if user_profile.avatar_url and user_profile.display_name %} --> - <div class="profile"> - <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> - <div class="profile-details"> - <div class="display-name">{{ user_profile.display_name }}</div> - <div class="user-id">{{ user_id }}</div> - </div> - </div> - <!-- {% endif %} --> + {% include "sso_partial_profile.html" %} <form method="post" action="{{my_url}}" id="consent_form"> <p> <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required> - <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank">terms and conditions</a>.</label> + <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank" rel="noopener">terms and conditions</a>.</label> </p> <input type="submit" class="primary-button" value="Continue"/> </form> </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_partial_profile.html b/synapse/res/templates/sso_partial_profile.html new file mode 100644 index 0000000000..c9c76c455e --- /dev/null +++ b/synapse/res/templates/sso_partial_profile.html @@ -0,0 +1,19 @@ +{# html fragment to be included in SSO pages, to show the user's profile #} + +<div class="profile{% if user_profile.avatar_url %} with-avatar{% endif %}"> + {% if user_profile.avatar_url %} + <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> + {% endif %} + {# users that signed up with SSO will have a display_name of some sort; + however that is not the case for users who signed up via other + methods, so we need to handle that. + #} + {% if user_profile.display_name %} + <div class="display-name">{{ user_profile.display_name }}</div> + {% else %} + {# split the userid on ':', take the part before the first ':', + and then remove the leading '@'. #} + <div class="display-name">{{ user_id.split(":")[0][1:] }}</div> + {% endif %} + <div class="user-id">{{ user_id }}</div> +</div> diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index d1328a6969..1b01471ac8 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -2,35 +2,39 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>SSO redirect confirmation</title> + <title>Continue to your account</title> <meta name="viewport" content="width=device-width, user-scalable=no"> <style type="text/css"> {% include "sso.css" without context %} + + .confirm-trust { + margin: 34px 0; + color: #8D99A5; + } + .confirm-trust strong { + color: #17191C; + } + + .confirm-trust::before { + content: ""; + background-image: url(''); + background-repeat: no-repeat; + width: 24px; + height: 24px; + display: block; + float: left; + } </style> </head> <body> <header> - {% if new_user %} - <h1>Your account is now ready</h1> - <p>You've made your account on {{ server_name }}.</p> - {% else %} - <h1>Log in</h1> - {% endif %} - <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p> + <h1>Continue to your account</h1> </header> <main> - {% if user_profile.avatar_url %} - <div class="profile"> - <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> - <div class="profile-details"> - {% if user_profile.display_name %} - <div class="display-name">{{ user_profile.display_name }}</div> - {% endif %} - <div class="user-id">{{ user_id }}</div> - </div> - </div> - {% endif %} + {% include "sso_partial_profile.html" %} + <p class="confirm-trust">Continuing will grant <strong>{{ display_url }}</strong> access to your account.</p> <a href="{{ redirect_url }}" class="primary-button">Continue</a> </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index f5c5d164f9..8457db1e22 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import ( JoinRoomAliasServlet, ListRoomRestServlet, MakeRoomAdminRestServlet, + RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, RoomStateRestServlet, @@ -238,6 +239,7 @@ def register_servlets(hs, http_server): MakeRoomAdminRestServlet(hs).register(http_server) ShadowBanRestServlet(hs).register(http_server) ForwardExtremitiesRestServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 3e57e6a4d0..491f9ca095 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -15,9 +15,11 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple +from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.filtering import Filter from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -33,6 +35,7 @@ from synapse.rest.admin._base import ( ) from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -605,3 +608,65 @@ class ForwardExtremitiesRestServlet(RestServlet): extremities = await self.store.get_forward_extremities_for_room(room_id) return 200, {"count": len(extremities), "results": extremities} + + +class RoomEventContextServlet(RestServlet): + """ + Provide the context for an event. + This API is designed to be used when system administrators wish to look at + an abuse report and understand what happened during and immediately prior + to this event. + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$") + + def __init__(self, hs): + super().__init__() + self.clock = hs.get_clock() + self.room_context_handler = hs.get_room_context_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + await assert_user_is_admin(self.auth, requester.user) + + limit = parse_integer(request, "limit", default=10) + + # picking the API shape for symmetry with /messages + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] + else: + event_filter = None + + results = await self.room_context_handler.get_event_context( + requester, + room_id, + event_id, + limit, + event_filter, + use_admin_priviledge=True, + ) + + if not results: + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + time_now = self.clock.time_msec() + results["events_before"] = await self._event_serializer.serialize_events( + results["events_before"], time_now + ) + results["event"] = await self._event_serializer.serialize_event( + results["event"], time_now + ) + results["events_after"] = await self._event_serializer.serialize_events( + results["events_after"], time_now + ) + results["state"] = await self._event_serializer.serialize_events( + results["state"], time_now + ) + + return 200, results diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 68c3c64a0d..9350c704b9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -752,7 +752,7 @@ class PushersRestServlet(RestServlet): Returns: pushers: Dictionary containing pushers information. - total: Number of pushers in dictonary `pushers`. + total: Number of pushers in dictionary `pushers`. """ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f95627ee61..90fd98c53e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -650,7 +650,7 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = await self.room_context_handler.get_event_context( - requester.user, room_id, event_id, limit, event_filter + requester, room_id, event_id, limit, event_filter ) if not results: diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 5b5da71815..4fe712b30c 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,13 +16,24 @@ import logging from functools import wraps +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import GroupID +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.types import GroupID, JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -33,7 +44,7 @@ def _validate_group_id(f): """ @wraps(f) - def wrapper(self, request, group_id, *args, **kwargs): + def wrapper(self, request: Request, group_id: str, *args, **kwargs): if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) @@ -48,14 +59,14 @@ class GroupServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -66,11 +77,15 @@ class GroupServlet(RestServlet): return 200, group_description @_validate_group_id - async def on_POST(self, request, group_id): + async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert_params_in_dict( + content, ("name", "avatar_url", "short_description", "long_description") + ) + assert isinstance(self.groups_handler, GroupsLocalHandler) await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, category_id, room_id): + async def on_PUT( + self, request: Request, group_id: str, category_id: str, room_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, @@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, category_id, room_id): + async def on_DELETE( + self, request: Request, group_id: str, category_id: str, room_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet): "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id, category_id): + async def on_GET( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet): return 200, category @_validate_group_id - async def on_PUT(self, request, group_id, category_id): + async def on_PUT( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) @@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, category_id): + async def on_DELETE( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id, role_id): + async def on_GET( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet): return 200, category @_validate_group_id - async def on_PUT(self, request, group_id, role_id): + async def on_PUT( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) @@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, role_id): + async def on_DELETE( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet): "/users/(?P<user_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, role_id, user_id): + async def on_PUT( + self, request: Request, group_id: str, role_id: str, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, @@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, role_id, user_id): + async def on_DELETE( + self, request: Request, group_id: str, role_id: str, user_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet): PATTERNS = client_patterns("/create_group$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet): "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, room_id): + async def on_PUT( + self, request: Request, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) @@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result @_validate_group_id - async def on_DELETE(self, request, group_id, room_id): + async def on_DELETE( + self, request: Request, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet): "/config/(?P<config_key>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, room_id, config_key): + async def on_PUT( + self, request: Request, group_id: str, room_id: str, config_key: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet): "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet): self.is_mine_id = hs.is_mine_id @_validate_group_id - async def on_PUT(self, request, group_id, user_id): + async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet): "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, user_id): + async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance(self.groups_handler, GroupsLocalHandler) result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request, user_id): + async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) result = await self.groups_handler.get_publicised_groups_for_user(user_id) @@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/joined_groups$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet): return 200, result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 10e1891174..e3d322f2ac 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -193,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): body, ["client_secret", "country", "phone_number", "send_attempt"] ) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) country = body["country"] phone_number = body["phone_number"] send_attempt = body["send_attempt"] @@ -293,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet): sid = parse_string(request, "sid", required=True) client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) token = parse_string(request, "token", required=True) # Attempt to validate a 3PID session diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index bf030e0ff4..147920767f 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class RoomUpgradeRestServlet(RestServlet): - """Handler for room uprade requests. + """Handler for room upgrade requests. Handles requests of the form: diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index f71a03a12d..90bbeca679 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -137,7 +137,7 @@ def add_file_headers( # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` # is (essentially) a single US-ASCII word, and a `quoted-string` is a # US-ASCII string surrounded by double-quotes, using backslash as an - # escape charater. Note that %-encoding is *not* permitted. + # escape character. Note that %-encoding is *not* permitted. # # `filename*` is defined to be an `ext-value`, which is defined in # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 4c9946a616..635bccf775 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -184,7 +184,7 @@ class MediaRepository: async def get_local_media( self, request: Request, media_id: str, name: Optional[str] ) -> None: - """Responds to reqests for local media, if exists, or returns 404. + """Responds to requests for local media, if exists, or returns 404. Args: request: The incoming request. @@ -306,7 +306,7 @@ class MediaRepository: media_info = await self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already - # seen the file then reuse the existing ID, otherwise genereate a new + # seen the file then reuse the existing ID, otherwise generate a new # one. # If we have an entry in the DB, try and look for it @@ -927,10 +927,10 @@ class MediaRepositoryResource(Resource): <thumbnail> - The thumbnail methods are "crop" and "scale". "scale" trys to return an + The thumbnail methods are "crop" and "scale". "scale" tries to return an image where either the width or the height is smaller than the requested size. The client should then scale and letterbox the image if it needs to - fit within a given rectangle. "crop" trys to return an image where the + fit within a given rectangle. "crop" tries to return an image where the width and height are close to the requested size and the aspect matches the requested size. The client should scale the image if it needs to fit within a given rectangle. diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 89cdd605aa..aba6d689a8 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -16,13 +16,17 @@ import contextlib import logging import os import shutil -from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence + +import attr from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder @@ -58,6 +62,8 @@ class MediaStorage: self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other @@ -127,18 +133,29 @@ class MediaStorage: f.flush() f.close() + spam = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + raise SpamMediaException() + for provider in self.storage_providers: await provider.store_file(path, file_info) finished_called[0] = True yield f, fname, finish - except Exception: + except Exception as e: try: os.remove(fname) except Exception: pass - raise + + raise e from None if not finished_called: raise Exception("Finished callback not called") @@ -302,3 +319,39 @@ class FileResponder(Responder): def __exit__(self, exc_type, exc_val, exc_tb): self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2 ** 14 + + clock = attr.ib(type=Clock) + path = attr.ib(type=str) + + async def write_chunks_to(self, callback: Callable[[bytes], None]): + """Reads the file in chunks and calls the callback with each chunk. + """ + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index bf3be653aa..ae53b1d23f 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -58,7 +58,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) +_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I) +_xml_encoding_match = re.compile( + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I +) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) OG_TAG_NAME_MAXLEN = 50 @@ -300,24 +303,7 @@ class PreviewUrlResource(DirectServeJsonResource): with open(media_info["filename"], "rb") as file: body = file.read() - encoding = None - - # Let's try and figure out if it has an encoding set in a meta tag. - # Limit it to the first 1kb, since it ought to be in the meta tags - # at the top. - match = _charset_match.search(body[:1000]) - - # If we find a match, it should take precedence over the - # Content-Type header, so set it here. - if match: - encoding = match.group(1).decode("ascii") - - # If we don't find a match, we'll look at the HTTP Content-Type, and - # if that doesn't exist, we'll fall back to UTF-8. - if not encoding: - content_match = _content_type_match.match(media_info["media_type"]) - encoding = content_match.group(1) if content_match else "utf-8" - + encoding = get_html_media_encoding(body, media_info["media_type"]) og = decode_and_calc_og(body, media_info["uri"], encoding) # pre-cache the image for posterity @@ -689,6 +675,48 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") +def get_html_media_encoding(body: bytes, content_type: str) -> str: + """ + Get the encoding of the body based on the (presumably) HTML body or media_type. + + The precedence used for finding a character encoding is: + + 1. meta tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to UTF-8. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Let's try and figure out if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + return match.group(1).decode("ascii") + + # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/> + + # If we didn't find a match, see if it an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + return match.group(1).decode("ascii") + + # If we don't find a match, we'll look at the HTTP Content-Type, and + # if that doesn't exist, we'll fall back to UTF-8. + content_match = _content_type_match.match(content_type) + if content_match: + return content_match.group(1) + + return "utf-8" + + def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: @@ -725,6 +753,11 @@ def decode_and_calc_og( def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: # Attempt to parse the body. If this fails, log and return no metadata. tree = etree.fromstring(body_attempt, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return {} + return _calc_og(tree, media_uri) # Attempt to parse the body. If this fails, log and return no metadata. diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 6da76ae994..1136277794 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -22,6 +22,7 @@ from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion - content_uri = await self.media_repo.create_content( - media_type, upload_name, request.content, content_length, requester.user - ) + try: + content_uri = await self.media_repo.create_content( + media_type, upload_name, request.content, content_length, requester.user + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") logger.info("Uploaded content with URI %r", content_uri) diff --git a/synapse/server.py b/synapse/server.py index 9bdd3177d7..6b3892e3cd 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,17 @@ import abc import functools import logging import os -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + TypeVar, + Union, + cast, +) import twisted.internet.base import twisted.internet.tcp @@ -588,7 +598,9 @@ class HomeServer(metaclass=abc.ABCMeta): return UserDirectoryHandler(self) @cache_in_self - def get_groups_local_handler(self): + def get_groups_local_handler( + self, + ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]: if self.config.worker_app: return GroupsLocalWorkerHandler(self) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3bd9ff8ca0..28544ccb92 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -615,7 +615,7 @@ class StateResolutionHandler: event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_map_factory. If None, all events will be fetched via state_res_store. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d2ba4bd2fc..ae4bf1a54f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -158,8 +158,8 @@ class LoggingDatabaseConnection: def commit(self) -> None: self.conn.commit() - def rollback(self, *args, **kwargs) -> None: - self.conn.rollback(*args, **kwargs) + def rollback(self) -> None: + self.conn.rollback() def __enter__(self) -> "Connection": self.conn.__enter__() @@ -244,12 +244,15 @@ class LoggingTransaction: assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) + def fetchone(self) -> Optional[Tuple]: + return self.txn.fetchone() + + def fetchmany(self, size: Optional[int] = None) -> List[Tuple]: + return self.txn.fetchmany(size=size) + def fetchall(self) -> List[Tuple]: return self.txn.fetchall() - def fetchone(self) -> Tuple: - return self.txn.fetchone() - def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() @@ -754,6 +757,7 @@ class DatabasePool: Returns: A list of dicts where the key is the column header. """ + assert cursor.description is not None, "cursor.description was None" col_headers = [intern(str(column[0])) for column in cursor.description] results = [dict(zip(col_headers, row)) for row in cursor] return results diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 31f70ac5ef..45ca6620a8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -450,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): }, ) - # Add the messages to the approriate local device inboxes so that + # Add the messages to the appropriate local device inboxes so that # they'll be sent to the devices when they next sync. self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 8326640d20..ddfb13e3ad 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -371,7 +371,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # and state sets {A} and {B} then walking the auth chains of A and B # would immediately show that C is reachable by both. However, if we # stopped at C then we'd only reach E via the auth chain of B and so E - # would errornously get included in the returned difference. + # would erroneously get included in the returned difference. # # The other thing that we do is limit the number of auth chains we walk # at once, due to practical limits (i.e. we can only query the database @@ -497,7 +497,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas a_ids = new_aids - # Mark that the auth event is reachable by the approriate sets. + # Mark that the auth event is reachable by the appropriate sets. sets.intersection_update(event_to_missing_sets[event_id]) search.sort() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ccda9f1caa..7abfb9112e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1050,7 +1050,7 @@ class PersistEventsStore: # Figure out the changes of membership to invalidate the # `get_rooms_for_user` cache. # We find out which membership events we may have deleted - # and which we have added, then we invlidate the caches for all + # and which we have added, then we invalidate the caches for all # those users. members_changed = { state_key diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 04ac2d0ced..e97026dc2e 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -155,7 +155,7 @@ class KeyStore(SQLBaseStore): (server_name, key_id, from_server) triplet if one already existed. Args: server_name: The name of the server. - key_id: The identifer of the key this JSON is for. + key_id: The identifier of the key this JSON is for. from_server: The server this JSON was fetched from. ts_now_ms: The time now in milliseconds. ts_valid_until_ms: The time when this json stops being valid. @@ -182,7 +182,7 @@ class KeyStore(SQLBaseStore): async def get_server_keys_json( self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: - """Retrive the key json for a list of server_keys and key ids. + """Retrieve the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. The JSON is returned as a byte array so that it can be efficiently diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 92e65aa640..614a418a15 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -111,7 +111,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_sent_e2ee_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. + # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname sql = """ @@ -167,7 +167,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. + # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname sql = """ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e4843a202c..ae9283f52d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -160,7 +160,7 @@ class ReceiptsWorkerStore(SQLBaseStore): Args: room_id: List of room_ids. - to_key: Max stream id to fetch receipts upto. + to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. @@ -189,7 +189,7 @@ class ReceiptsWorkerStore(SQLBaseStore): Args: room_ids: The room id. - to_key: Max stream id to fetch receipts upto. + to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. @@ -312,7 +312,7 @@ class ReceiptsWorkerStore(SQLBaseStore): to a limit of the latest 100 read receipts. Args: - to_key: Max stream id to fetch receipts upto. + to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a9fcb5f59c..cba343aa68 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1044,7 +1044,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int ): - """Background update to go and add room version inforamtion to `rooms` + """Background update to go and add room version information to `rooms` table from `current_state_events` table. """ diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite index a0411ede7e..308124e531 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite @@ -67,11 +67,6 @@ CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT CREATE INDEX user_threepids_user_id ON user_threepids(user_id); CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) /* event_search(event_id,room_id,sender,"key",value) */; -CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); -CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB); CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) ); CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) ); CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); @@ -149,11 +144,6 @@ CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_las CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) /* user_directory_search(user_id,value) */; -CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value'); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB); CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 356623fc6e..0dbb501f16 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -64,7 +64,7 @@ class StateDeltasStore(SQLBaseStore): def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than # N results. - # We arbitarily limit to 100 stream_id entries to ensure we don't + # We arbitrarily limit to 100 stream_id entries to ensure we don't # select toooo many. sql = """ SELECT stream_id, count(*) @@ -81,7 +81,7 @@ class StateDeltasStore(SQLBaseStore): for stream_id, count in txn: total += count if total > 100: - # We arbitarily limit to 100 entries to ensure we don't + # We arbitrarily limit to 100 entries to ensure we don't # select toooo many. logger.debug( "Clipping current_state_delta_stream rows to stream_id %i", diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index cea595ff19..248a6c3f25 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -198,7 +198,7 @@ class TransactionStore(TransactionWorkerStore): retry_interval: int, ) -> None: """Sets the current retry timings for a given destination. - Both timings should be zero if retrying is no longer occuring. + Both timings should be zero if retrying is no longer occurring. Args: destination diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index acb24e33af..1fd333b707 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -27,7 +27,7 @@ MAX_STATE_DELTA_HOPS = 100 class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud + """Defines functions related to state groups needed to run the state background updates. """ diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 035f9ea6e9..d15ccfacde 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine @@ -28,11 +27,8 @@ def create_engine(database_config) -> BaseDatabaseEngine: return Sqlite3Engine(sqlite3, database_config) if name == "psycopg2": - # pypy requires psycopg2cffi rather than psycopg2 - if platform.python_implementation() == "PyPy": - import psycopg2cffi as psycopg2 # type: ignore - else: - import psycopg2 # type: ignore + # Note that psycopg2cffi-compat provides the psycopg2 module on pypy. + import psycopg2 # type: ignore return PostgresEngine(psycopg2, database_config) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 5db0f0b520..b3d1834efb 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform import struct import threading import typing @@ -30,6 +31,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): database = database_config.get("args", {}).get("database") self._is_in_memory = database in (None, ":memory:",) + if platform.python_implementation() == "PyPy": + # pypy's sqlite3 module doesn't handle bytearrays, convert them + # back to bytes. + database_module.register_adapter(bytearray, lambda array: bytes(array)) + # The current max state_group, or None if we haven't looked # in the DB yet. self._current_state_group_id = None diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 566ea19bae..cd30e6b80a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -113,7 +113,7 @@ def prepare_database( # which should be empty. if config is None: raise ValueError( - "config==None in prepare_database, but databse is not empty" + "config==None in prepare_database, but database is not empty" ) # if it's a worker app, refuse to upgrade the database, to avoid multiple @@ -619,9 +619,9 @@ def _get_or_create_schema_state( txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() - current_version = int(row[0]) if row else None - if current_version: + if row is not None: + current_version = int(row[0]) txn.execute( "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 9cadcba18f..17291c9d5e 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union from typing_extensions import Protocol @@ -20,23 +20,44 @@ from typing_extensions import Protocol Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ +_Parameters = Union[Sequence[Any], Mapping[str, Any]] + class Cursor(Protocol): - def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + def execute(self, sql: str, parameters: _Parameters = ...) -> Any: ... - def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any: ... - def fetchall(self) -> List[Tuple]: + def fetchone(self) -> Optional[Tuple]: + ... + + def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: ... - def fetchone(self) -> Tuple: + def fetchall(self) -> List[Tuple]: ... @property - def description(self) -> Any: - return None + def description( + self, + ) -> Optional[ + Sequence[ + # Note that this is an approximate typing based on sqlite3 and other + # drivers, and may not be entirely accurate. + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: + ... @property def rowcount(self) -> int: @@ -59,7 +80,7 @@ class Connection(Protocol): def commit(self) -> None: ... - def rollback(self, *args, **kwargs) -> None: + def rollback(self) -> None: ... def __enter__(self) -> "Connection": diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 71ef5a72dc..9dd537bf66 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -245,7 +245,7 @@ class MultiWriterIdGenerator: # and b) noting that if we have seen a run of persisted positions # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7). # - # Note: There is no guarentee that the IDs generated by the sequence + # Note: There is no guarantee that the IDs generated by the sequence # will be gapless; gaps can form when e.g. a transaction was rolled # back. This means that sometimes we won't be able to skip forward the # position even though everything has been persisted. However, since @@ -418,7 +418,7 @@ class MultiWriterIdGenerator: # bother, as nothing will read it). # # We only do this on the success path so that the persisted current - # position points to a persited row with the correct instance name. + # position points to a persisted row with the correct instance name. if self._writers: txn.call_after( run_as_background_process, @@ -509,7 +509,7 @@ class MultiWriterIdGenerator: } def advance(self, instance_name: str, new_id: int): - """Advance the postion of the named writer to the given ID, if greater + """Advance the position of the named writer to the given ID, if greater than existing entry. """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 0ec4dc2918..e2b316a218 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator): def get_next_id_txn(self, txn: Cursor) -> int: txn.execute("SELECT nextval(?)", (self._sequence_name,)) - return txn.fetchone()[0] + fetch_res = txn.fetchone() + assert fetch_res is not None + return fetch_res[0] def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: txn.execute( @@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator): txn.execute( "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) - last_value, is_called = txn.fetchone() + fetch_res = txn.fetchone() + assert fetch_res is not None + last_value, is_called = fetch_res # If we have an associated stream check the stream_positions table. max_in_stream_positions = None diff --git a/synapse/types.py b/synapse/types.py index eafe729dfe..c695558a86 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -675,7 +675,7 @@ class PersistedEventPosition: persisted in the same room after this position will be after the returned `RoomStreamToken`. - Note: no guarentees are made about ordering w.r.t. events in other + Note: no guarantees are made about ordering w.r.t. events in other rooms. """ # Doing the naive thing satisfies the desired properties described in diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 9a873c8e8e..691dde9a01 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -497,7 +497,7 @@ def timeout_deferred( delayed_call = reactor.callLater(timeout, time_it_out) def convert_cancelled(value: failure.Failure): - # if the orgininal deferred was cancelled, and our timeout has fired, then + # if the original deferred was cancelled, and our timeout has fired, then # the reason it was cancelled was due to our timeout. Turn the CancelledError # into a TimeoutError. if timed_out[0] and value.check(CancelledError): diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index f8038bf861..9ce7873ab5 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$") # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically @@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") rand = random.SystemRandom() -def random_string(length): +def random_string(length: int) -> str: return "".join(rand.choice(string.ascii_letters) for _ in range(length)) -def random_string_with_symbols(length): +def random_string_with_symbols(length: int) -> str: return "".join(rand.choice(_string_with_symbols) for _ in range(length)) -def is_ascii(s): - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True +def is_ascii(s: bytes) -> bool: + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False + return True -def assert_valid_client_secret(client_secret): - """Validate that a given string matches the client_secret regex defined by the spec""" - if client_secret_regex.match(client_secret) is None: +def assert_valid_client_secret(client_secret: str) -> None: + """Validate that a given string matches the client_secret defined by the spec""" + if ( + len(client_secret) <= 0 + or len(client_secret) > 255 + or CLIENT_SECRET_REGEX.match(client_secret) is None + ): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) diff --git a/synapse/visibility.py b/synapse/visibility.py index ec50e7e977..e39d02602a 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -80,6 +80,7 @@ async def filter_events_for_client( events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), @@ -233,7 +234,7 @@ async def filter_events_for_client( elif visibility == HistoryVisibility.SHARED and is_peeking: # if the visibility is shared, users cannot see the event unless - # they have *subequently* joined the room (or were members at the + # they have *subsequently* joined the room (or were members at the # time, of course) # # XXX: if the user has subsequently joined and then left again, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index ee5217b074..b1a8c58e1c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -17,8 +17,6 @@ from mock import Mock import pymacaroons -from twisted.internet import defer - from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -33,19 +31,17 @@ from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest -from tests.utils import mock_getRawHeaders, setup_test_homeserver +from tests.test_utils import simple_async_mock +from tests.utils import mock_getRawHeaders -class AuthTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.state_handler = Mock() +class AuthTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = Mock() - self.hs = yield setup_test_homeserver(self.addCleanup) - self.hs.get_datastore = Mock(return_value=self.store) - self.hs.get_auth_handler().store = self.store - self.auth = Auth(self.hs) + hs.get_datastore = Mock(return_value=self.store) + hs.get_auth_handler().store = self.store + self.auth = Auth(hs) # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -57,64 +53,59 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) - self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) - self.store.is_support_user = Mock(return_value=defer.succeed(False)) + self.store.insert_client_ip = simple_async_mock(None) + self.store.is_support_user = simple_async_mock(False) - @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) - self.store.get_user_by_access_token = Mock( - return_value=defer.succeed(user_info) - ) + self.store.get_user_by_access_token = simple_async_mock(user_info) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, InvalidClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), InvalidClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = TokenLookupResult(user_id=self.test_user, token_id=5) - self.store.get_user_by_access_token = Mock( - return_value=defer.succeed(user_info) - ) + self.store.get_user_by_access_token = simple_async_mock(user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, MissingClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), MissingClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): app_service = Mock( token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) - @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_good_ip(self): from netaddr import IPSet @@ -125,13 +116,13 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -144,42 +135,44 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, InvalidClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), InvalidClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, InvalidClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), InvalidClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - f = self.failureResultOf(d, MissingClientTokenError).value + f = self.get_failure( + self.auth.get_user_by_req(request), MissingClientTokenError + ).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") - @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = b"@doppelganger:matrix.org" app_service = Mock( @@ -188,17 +181,15 @@ class AuthTestCase(unittest.TestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = Mock( - return_value=defer.succeed({"is_guest": False}) - ) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) + requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -210,22 +201,18 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = defer.ensureDeferred(self.auth.get_user_by_req(request)) - self.failureResultOf(d, AuthError) + self.get_failure(self.auth.get_user_by_req(request), AuthError) - @defer.inlineCallbacks def test_get_user_from_macaroon(self): - self.store.get_user_by_access_token = Mock( - return_value=defer.succeed( - TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") - ) + self.store.get_user_by_access_token = simple_async_mock( + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) user_id = "@baldrick:matrix.org" @@ -237,7 +224,7 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield defer.ensureDeferred( + user_info = self.get_success( self.auth.get_user_by_access_token(macaroon.serialize()) ) self.assertEqual(user_id, user_info.user_id) @@ -246,10 +233,9 @@ class AuthTestCase(unittest.TestCase): # from the db. self.assertEqual(user_info.device_id, "device") - @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) - self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) + self.store.get_user_by_id = simple_async_mock({"is_guest": True}) + self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -263,20 +249,17 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield defer.ensureDeferred( - self.auth.get_user_by_access_token(serialized) - ) + user_info = self.get_success(self.auth.get_user_by_access_token(serialized)) self.assertEqual(user_id, user_info.user_id) self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) - @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) - self.store.get_device = Mock(return_value=defer.succeed(None)) + self.store.add_access_token_to_user = simple_async_mock(None) + self.store.get_device = simple_async_mock(None) - token = yield defer.ensureDeferred( + token = self.get_success( self.hs.get_auth_handler().get_access_token_for_user_id( USER_ID, "DEVICE", valid_until_ms=None ) @@ -289,25 +272,21 @@ class AuthTestCase(unittest.TestCase): puppets_user_id=None, ) - def get_user(tok): + async def get_user(tok): if token != tok: - return defer.succeed(None) - return defer.succeed( - TokenLookupResult( - user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", - ) + return None + return TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock( - return_value=defer.succeed({"is_guest": False}) - ) + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) # check the token works request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield defer.ensureDeferred( + requester = self.get_success( self.auth.get_user_by_req(request, allow_guest=True) ) self.assertEqual(UserID.from_string(USER_ID), requester.user) @@ -323,17 +302,16 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [guest_tok.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - with self.assertRaises(InvalidClientCredentialsError) as cm: - yield defer.ensureDeferred( - self.auth.get_user_by_req(request, allow_guest=True) - ) + cm = self.get_failure( + self.auth.get_user_by_req(request, allow_guest=True), + InvalidClientCredentialsError, + ) - self.assertEqual(401, cm.exception.code) - self.assertEqual("Guest access token used for regular user", cm.exception.msg) + self.assertEqual(401, cm.value.code) + self.assertEqual("Guest access token used for regular user", cm.value.msg) self.store.get_user_by_id.assert_called_with(USER_ID) - @defer.inlineCallbacks def test_blocking_mau(self): self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._max_mau_value = 50 @@ -341,77 +319,61 @@ class AuthTestCase(unittest.TestCase): small_number_of_users = 1 # Ensure no error thrown - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.get_success(self.auth.check_auth_blocking()) self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(lots_of_users) - ) + self.store.get_monthly_active_count = simple_async_mock(lots_of_users) - with self.assertRaises(ResourceLimitError) as e: - yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) + e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.value.code, 403) # Ensure does not throw an error - self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(small_number_of_users) - ) - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) + self.get_success(self.auth.check_auth_blocking()) - @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + self.store.get_monthly_active_count = simple_async_mock(100) # Support users allowed - yield defer.ensureDeferred( - self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) - ) - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)) + self.store.get_monthly_active_count = simple_async_mock(100) # Bots not allowed - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth.check_auth_blocking(user_type=UserTypes.BOT) - ) - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) + self.get_failure( + self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError + ) + self.store.get_monthly_active_count = simple_async_mock(100) # Real users not allowed - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - @defer.inlineCallbacks def test_reserved_threepid(self): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 - self.store.get_monthly_active_count = lambda: defer.succeed(2) + self.store.get_monthly_active_count = simple_async_mock(2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred(self.auth.check_auth_blocking()) + self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth.check_auth_blocking(threepid=unknown_threepid) - ) + self.get_failure( + self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError + ) - yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) + self.get_success(self.auth.check_auth_blocking(threepid=threepid)) - @defer.inlineCallbacks def test_hs_disabled(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - with self.assertRaises(ResourceLimitError) as e: - yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) + e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.value.code, 403) - @defer.inlineCallbacks def test_hs_disabled_no_server_notices_user(self): """Check that 'hs_disabled_message' works correctly when there is no server_notices user. @@ -422,16 +384,14 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - with self.assertRaises(ResourceLimitError) as e: - yield defer.ensureDeferred(self.auth.check_auth_blocking()) - self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.assertEquals(e.exception.code, 403) + e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.value.code, 403) - @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): self.auth_blocking._hs_disabled = True user = "@user:server" self.auth_blocking._server_notices_mxid = user self.auth_blocking._hs_disabled_message = "Reason for being disabled" - yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) + self.get_success(self.auth.check_auth_blocking(user)) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 279c94a03d..ab7d290724 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -18,15 +18,12 @@ import jsonschema -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import make_event_from_dict from tests import unittest -from tests.utils import setup_test_homeserver user_localpart = "test_user" @@ -39,9 +36,8 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) -class FilteringTestCase(unittest.TestCase): - def setUp(self): - hs = setup_test_homeserver(self.addCleanup) +class FilteringTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.filtering = hs.get_filtering() self.datastore = hs.get_datastore() @@ -351,10 +347,9 @@ class FilteringTestCase(unittest.TestCase): self.assertTrue(Filter(definition).check(event)) - @defer.inlineCallbacks def test_filter_presence_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -362,7 +357,7 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) @@ -371,11 +366,10 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_presence(events=events) self.assertEquals(events, results) - @defer.inlineCallbacks def test_filter_presence_no_match(self): user_filter_json = {"presence": {"types": ["m.*"]}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart + "2", user_filter=user_filter_json ) @@ -387,7 +381,7 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart + "2", filter_id=filter_id ) @@ -396,10 +390,9 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_presence(events=events) self.assertEquals([], results) - @defer.inlineCallbacks def test_filter_room_state_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -407,7 +400,7 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) @@ -416,10 +409,9 @@ class FilteringTestCase(unittest.TestCase): results = user_filter.filter_room_state(events=events) self.assertEquals(events, results) - @defer.inlineCallbacks def test_filter_room_state_no_match(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -429,7 +421,7 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield defer.ensureDeferred( + user_filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) @@ -454,11 +446,10 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) - @defer.inlineCallbacks def test_add_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.filtering.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) @@ -468,7 +459,7 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals( user_filter_json, ( - yield defer.ensureDeferred( + self.get_success( self.datastore.get_user_filter( user_localpart=user_localpart, filter_id=0 ) @@ -476,17 +467,16 @@ class FilteringTestCase(unittest.TestCase): ), ) - @defer.inlineCallbacks def test_get_filter(self): user_filter_json = {"room": {"state": {"types": ["m.*"]}}} - filter_id = yield defer.ensureDeferred( + filter_id = self.get_success( self.datastore.add_user_filter( user_localpart=user_localpart, user_filter=user_filter_json ) ) - filter = yield defer.ensureDeferred( + filter = self.get_success( self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id ) diff --git a/tests/config/test_server.py b/tests/config/test_server.py index a10d017120..98af7aa675 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -15,7 +15,8 @@ import yaml -from synapse.config.server import ServerConfig, is_threepid_reserved +from synapse.config._base import ConfigError +from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved from tests import unittest @@ -128,3 +129,61 @@ class ServerConfigTestCase(unittest.TestCase): ) self.assertEqual(conf["listeners"], expected_listeners) + + +class GenerateIpSetTestCase(unittest.TestCase): + def test_empty(self): + ip_set = generate_ip_set(()) + self.assertFalse(ip_set) + + ip_set = generate_ip_set((), ()) + self.assertFalse(ip_set) + + def test_generate(self): + """Check adding IPv4 and IPv6 addresses.""" + # IPv4 address + ip_set = generate_ip_set(("1.2.3.4",)) + self.assertEqual(len(ip_set.iter_cidrs()), 4) + + # IPv4 CIDR + ip_set = generate_ip_set(("1.2.3.4/24",)) + self.assertEqual(len(ip_set.iter_cidrs()), 4) + + # IPv6 address + ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",)) + self.assertEqual(len(ip_set.iter_cidrs()), 1) + + # IPv6 CIDR + ip_set = generate_ip_set(("2001:db8::/104",)) + self.assertEqual(len(ip_set.iter_cidrs()), 1) + + # The addresses can overlap OK. + ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4")) + self.assertEqual(len(ip_set.iter_cidrs()), 4) + + def test_extra(self): + """Extra IP addresses are treated the same.""" + ip_set = generate_ip_set((), ("1.2.3.4",)) + self.assertEqual(len(ip_set.iter_cidrs()), 4) + + ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",)) + self.assertEqual(len(ip_set.iter_cidrs()), 8) + + # They can duplicate without error. + ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",)) + self.assertEqual(len(ip_set.iter_cidrs()), 4) + + def test_bad_value(self): + """An error should be raised if a bad value is passed in.""" + with self.assertRaises(ConfigError): + generate_ip_set(("not-an-ip",)) + + with self.assertRaises(ConfigError): + generate_ip_set(("1.2.3.4/128",)) + + with self.assertRaises(ConfigError): + generate_ip_set((":::",)) + + # The following get treated as empty data. + self.assertFalse(generate_ip_set(None)) + self.assertFalse(generate_ip_set({})) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 53763cd0f9..d5d3fdd99a 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore.return_value = self.mock_store - self.mock_store.get_received_ts.return_value = defer.succeed(0) - self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None) + self.mock_store.get_received_ts.return_value = make_awaitable(0) + self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice(is_interested=False), ] - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed([]) + self.mock_store.get_user_by_id.return_value = make_awaitable([]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed(None) + self.mock_store.get_user_by_id.return_value = make_awaitable(None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id}) + self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): "query_user called when it shouldn't have been.", ) - @defer.inlineCallbacks def test_query_room_alias_exists(self): room_alias_str = "#foo:bar" room_alias = Mock() @@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): Mock(room_id=room_id, servers=servers) ) - result = yield defer.ensureDeferred( - self.handler.query_room_alias_exists(room_alias) + result = self.successResultOf( + defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias)) ) self.mock_as_api.query_alias.assert_called_once_with( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index e24ce81284..0e42013bb9 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -16,28 +16,21 @@ from mock import Mock import pymacaroons -from twisted.internet import defer - -import synapse -import synapse.api.errors -from synapse.api.errors import ResourceLimitError +from synapse.api.errors import AuthError, ResourceLimitError from tests import unittest from tests.test_utils import make_awaitable -from tests.utils import setup_test_homeserver -class AuthTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.auth_handler = self.hs.get_auth_handler() - self.macaroon_generator = self.hs.get_macaroon_generator() +class AuthTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.auth_handler = hs.get_auth_handler() + self.macaroon_generator = hs.get_macaroon_generator() # MAU tests # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking = hs.get_auth()._auth_blocking self.auth_blocking._max_mau_value = 50 self.small_number_of_users = 1 @@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.get_clock().now = 5000 - token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_nonce) v.verify(macaroon, self.hs.config.macaroon_secret_key) - @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): - self.hs.get_clock().now = 1000 - token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield defer.ensureDeferred( + user_id = self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected - self.hs.get_clock().now = 6000 - with self.assertRaises(synapse.api.errors.AuthError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id(token) - ) + self.reactor.advance(6) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token), + AuthError, + ) - @defer.inlineCallbacks def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield defer.ensureDeferred( + user_id = self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize() ) @@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase): # user_id. macaroon.add_first_party_caveat("user_id = b_user") - with self.assertRaises(synapse.api.errors.AuthError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() - ) - ) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ), + AuthError, + ) - @defer.inlineCallbacks def test_mau_limits_disabled(self): self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception - yield defer.ensureDeferred( + self.get_success( self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None ) ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) ) - @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None - ) - ) + self.get_failure( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ), + ResourceLimitError, + ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() - ) - ) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ), + ResourceLimitError, + ) - @defer.inlineCallbacks def test_mau_limits_parity(self): + # Ensure we're not at the unix epoch. + self.reactor.advance(1) self.auth_blocking._limit_usage_by_mau = True - # If not in monthly active cohort + # Set the server to be at the edge of too many users. self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None - ) - ) - self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) + # If not in monthly active cohort + self.get_failure( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ), + ResourceLimitError, ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() - ) - ) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ), + ResourceLimitError, + ) + # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.clock.time_msec()) ) - self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) - ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None ) ) - self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.hs.get_clock().time_msec()) - ) - self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) - ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) ) - @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True @@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase): return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception - yield defer.ensureDeferred( + self.get_success( self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None ) @@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 7baf224f7e..6f992291b8 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -16,7 +16,7 @@ from mock import Mock from synapse.handlers.cas_handler import CasResponse from tests.test_utils import simple_async_mock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. BASE_URL = "https://synapse/" @@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase): "server_url": SERVER_URL, "service_url": BASE_URL, } + + # Update this config with what's in the default config so that + # override_config works as expected. + cas_config.update(config.get("cas_config", {})) config["cas_config"] = cas_config return config @@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase): "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True ) + @override_config( + { + "cas_config": { + "required_attributes": {"userGroup": "staff", "department": None} + } + } + ) + def test_required_attributes(self): + """The required attributes must be met from the CAS response.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # The response doesn't have the proper userGroup or department. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_not_called() + + # The response doesn't have any department. + cas_response = CasResponse("test_user", {"userGroup": "staff"}) + request.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_not_called() + + # Add the proper attributes and it should succeed. + cas_response = CasResponse( + "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]} + ) + request.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None, new_user=True + ) + def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader"]) + return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 924f29f051..c1a13aeb71 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -18,42 +18,27 @@ import mock from signedjson import key as key, sign as sign -from twisted.internet import defer - -import synapse.handlers.e2e_keys -import synapse.storage -from synapse.api import errors from synapse.api.constants import RoomEncryptionAlgorithms +from synapse.api.errors import Codes, SynapseError -from tests import unittest, utils +from tests import unittest -class E2eKeysHandlerTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer - self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler - self.store = None # type: synapse.storage.Storage +class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=mock.Mock()) - @defer.inlineCallbacks - def setUp(self): - self.hs = yield utils.setup_test_homeserver( - self.addCleanup, federation_client=mock.Mock() - ) - self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_query_local_devices_no_devices(self): """If the user has no devices, we expect an empty list. """ local_user = "@boris:" + self.hs.hostname - res = yield defer.ensureDeferred( - self.handler.query_local_devices({local_user: None}) - ) + res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - @defer.inlineCallbacks def test_reupload_one_time_keys(self): """we should be able to re-upload the same keys""" local_user = "@boris:" + self.hs.hostname @@ -64,7 +49,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) @@ -73,14 +58,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) - @defer.inlineCallbacks def test_change_one_time_keys(self): """attempts to change one-time-keys should be rejected""" @@ -92,75 +76,64 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} - ) - ) - self.fail("No error when changing string key") - except errors.SynapseError: - pass - - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} - ) - ) - self.fail("No error when replacing dict key with string") - except errors.SynapseError: - pass - - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, - device_id, - {"one_time_keys": {"alg1:k1": {"key": "key"}}}, - ) - ) - self.fail("No error when replacing string key with dict") - except errors.SynapseError: - pass - - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, - ) - ) - self.fail("No error when replacing dict key") - except errors.SynapseError: - pass + # Error when changing string key + self.get_failure( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ), + SynapseError, + ) + + # Error when replacing dict key with strin + self.get_failure( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ), + SynapseError, + ) + + # Error when replacing string key with dict + self.get_failure( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ), + SynapseError, + ) + + # Error when replacing dict key + self.get_failure( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ), + SynapseError, + ) - @defer.inlineCallbacks def test_claim_one_time_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield defer.ensureDeferred( + res2 = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -173,7 +146,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_fallback_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" @@ -181,12 +153,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) self.assertEqual(res, []) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, device_id, @@ -195,14 +167,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) # we should now have an unused alg1 key - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) self.assertEqual(res, ["alg1"]) # claiming an OTK when no OTKs are available should return the fallback # key - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -213,13 +185,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) # we shouldn't have any unused fallback keys again - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) self.assertEqual(res, []) # claiming an OTK again should return the same fallback key - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -231,13 +203,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # if the user uploads a one-time key, the next claim should fetch the # one-time key, and then go back to the fallback - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": otk} ) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -246,7 +218,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -256,7 +228,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) - @defer.inlineCallbacks def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname @@ -270,9 +241,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys1) - ) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) keys2 = { "master_key": { @@ -284,16 +253,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys2) - ) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2)) - devices = yield defer.ensureDeferred( + devices = self.get_success( self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" local_user = "@boris:" + self.hs.hostname @@ -326,9 +292,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys1) - ) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -358,12 +322,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, "abc", {"device_keys": device_key_1} ) ) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, "def", {"device_keys": device_key_2} ) @@ -372,7 +336,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signatures_for_device_keys( local_user, {local_user: {"abc": device_key_1}} ) @@ -383,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signatures_for_device_keys( local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} ) @@ -391,7 +355,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield defer.ensureDeferred( + devices = self.get_success( self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] @@ -399,7 +363,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) - @defer.inlineCallbacks def test_self_signing_key_doesnt_show_up_as_device(self): """signing keys should be hidden when fetching a user's devices""" local_user = "@boris:" + self.hs.hostname @@ -413,29 +376,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys1) - ) - - res = None - try: - yield defer.ensureDeferred( - self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", - ) - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 400) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) - res = yield defer.ensureDeferred( - self.handler.query_local_devices({local_user: None}) + e = self.get_failure( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ), + SynapseError, ) + res = e.value.code + self.assertEqual(res, 400) + + res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - @defer.inlineCallbacks def test_upload_signatures(self): """should check signatures that are uploaded""" # set up a user with cross-signing keys and a device. This user will @@ -458,7 +414,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"device_keys": device_key} ) @@ -501,7 +457,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) ) @@ -515,14 +471,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signing_keys_for_user( other_user, {"master_key": other_master_key} ) ) # test various signature failures (see below) - ret = yield defer.ensureDeferred( + ret = self.get_success( self.handler.upload_signatures_for_device_keys( local_user, { @@ -602,20 +558,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) user_failures = ret["failures"][local_user] + self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE) self.assertEqual( - user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE + user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE ) - self.assertEqual( - user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE - ) - self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND) other_user_failures = ret["failures"][other_user] + self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND) self.assertEqual( - other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND - ) - self.assertEqual( - other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN + other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN ) # test successful signatures @@ -623,7 +575,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield defer.ensureDeferred( + ret = self.get_success( self.handler.upload_signatures_for_device_keys( local_user, { @@ -636,7 +588,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield defer.ensureDeferred( + ret = self.get_success( self.handler.query_devices( {"device_keys": {local_user: [], other_user: []}}, 0, local_user ) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 45f201a399..58773a0c38 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -19,14 +19,9 @@ import copy import mock -from twisted.internet import defer +from synapse.api.errors import SynapseError -import synapse.api.errors -import synapse.handlers.e2e_room_keys -import synapse.storage -from synapse.api import errors - -from tests import unittest, utils +from tests import unittest # sample room_key data for use in the tests room_keys = { @@ -45,51 +40,39 @@ room_keys = { } -class E2eRoomKeysHandlerTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer - self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(replication_layer=mock.Mock()) - @defer.inlineCallbacks - def setUp(self): - self.hs = yield utils.setup_test_homeserver( - self.addCleanup, replication_layer=mock.Mock() - ) - self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) - self.local_user = "@boris:" + self.hs.hostname + def prepare(self, reactor, clock, hs): + self.handler = hs.get_e2e_room_keys_handler() + self.local_user = "@boris:" + hs.hostname - @defer.inlineCallbacks def test_get_missing_current_version_info(self): """Check that we get a 404 if we ask for info about the current version if there is no version. """ - res = None - try: - yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_version_info(self.local_user), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_get_missing_version_info(self): """Check that we get a 404 if we ask for info about a specific version if it doesn't exist. """ - res = None - try: - yield defer.ensureDeferred( - self.handler.get_version_info(self.local_user, "bogus_version") - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_version_info(self.local_user, "bogus_version"), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.create_version( self.local_user, { @@ -101,7 +84,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -116,9 +99,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield defer.ensureDeferred( - self.handler.get_version_info(self.local_user, "1") - ) + res = self.get_success(self.handler.get_version_info(self.local_user, "1")) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -132,7 +113,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.create_version( self.local_user, { @@ -144,7 +125,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -156,11 +137,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_update_version(self): """Check that we can update versions. """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -171,7 +151,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.update_version( self.local_user, version, @@ -185,7 +165,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -197,32 +177,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_update_missing_version(self): """Check that we get a 404 on updating nonexistent versions """ - res = None - try: - yield defer.ensureDeferred( - self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, - ) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_update_omitted_version(self): """Check that the update succeeds if the version is missing from the body """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -233,7 +209,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - yield defer.ensureDeferred( + self.get_success( self.handler.update_version( self.local_user, version, @@ -245,7 +221,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -257,11 +233,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_update_bad_version(self): """Check that we get a 400 if the version in the body doesn't match """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -272,52 +247,41 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield defer.ensureDeferred( - self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, - ) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 400) - @defer.inlineCallbacks def test_delete_missing_version(self): """Check that we get a 404 on deleting nonexistent versions """ - res = None - try: - yield defer.ensureDeferred( - self.handler.delete_version(self.local_user, "1") - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.delete_version(self.local_user, "1"), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_delete_missing_current_version(self): """Check that we get a 404 on deleting nonexistent current version """ - res = None - try: - yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) - except errors.SynapseError as e: - res = e.code + e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.create_version( self.local_user, { @@ -329,36 +293,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, "1") # check we can delete it - yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) + self.get_success(self.handler.delete_version(self.local_user, "1")) # check that it's gone - res = None - try: - yield defer.ensureDeferred( - self.handler.get_version_info(self.local_user, "1") - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_version_info(self.local_user, "1"), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_get_missing_backup(self): """Check that we get a 404 on querying missing backup """ - res = None - try: - yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, "bogus_version") - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -369,33 +325,27 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest - @defer.inlineCallbacks def test_upload_room_keys_no_versions(self): """Check that we get a 404 on uploading keys when no versions are defined """ - res = None - try: - yield defer.ensureDeferred( - self.handler.upload_room_keys(self.local_user, "no_version", room_keys) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_upload_room_keys_bogus_version(self): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -406,22 +356,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield defer.ensureDeferred( - self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys - ) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -432,7 +377,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -443,20 +388,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "2") - res = None - try: - yield defer.ensureDeferred( - self.handler.upload_room_keys(self.local_user, "1", room_keys) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError + ) + res = e.value.code self.assertEqual(res, 403) - @defer.inlineCallbacks def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -467,17 +408,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org" ) @@ -485,18 +424,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) ) self.assertDictEqual(res, room_keys) - @defer.inlineCallbacks def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -507,12 +445,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) # get the etag to compare to future versions - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -522,37 +460,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -560,28 +494,25 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here - @defer.inlineCallbacks def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -593,13 +524,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(version, "1") # check for bulk-delete - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - yield defer.ensureDeferred( - self.handler.delete_room_keys(self.local_user, version) - ) - res = yield defer.ensureDeferred( + self.get_success(self.handler.delete_room_keys(self.local_user, version)) + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) @@ -607,15 +536,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - yield defer.ensureDeferred( + self.get_success( self.handler.delete_room_keys( self.local_user, version, room_id="!abc:matrix.org" ) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) @@ -623,15 +552,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - yield defer.ensureDeferred( + self.get_success( self.handler.delete_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 022943a10a..787fab7875 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -13,25 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. - from mock import Mock -from twisted.internet import defer - import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID from tests import unittest from tests.test_utils import make_awaitable -from tests.utils import setup_test_homeserver -class ProfileTestCase(unittest.TestCase): +class ProfileTestCase(unittest.HomeserverTestCase): """ Tests profile management. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler - hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, ) + return hs + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) + self.get_success(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() - self.hs = hs - @defer.inlineCallbacks def test_get_my_name(self): - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) - displayname = yield defer.ensureDeferred( - self.handler.get_displayname(self.frank) - ) + displayname = self.get_success(self.handler.get_displayname(self.frank)) self.assertEquals("Frank", displayname) - @defer.inlineCallbacks def test_set_my_name(self): - yield defer.ensureDeferred( + self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." ) @@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.frank.localpart) ) ), @@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase): ) # Set displayname again - yield defer.ensureDeferred( + self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank" ) @@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.frank.localpart) ) ), @@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase): ) # Set displayname to an empty string - yield defer.ensureDeferred( + self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "" ) ) self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_profile_displayname(self.frank.localpart) - ) - ) + (self.get_success(self.store.get_profile_displayname(self.frank.localpart))) ) - @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) self.assertEquals( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.frank.localpart) ) ), @@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase): ) # Setting displayname a second time is forbidden - d = defer.ensureDeferred( + self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." - ) + ), + SynapseError, ) - yield self.assertFailure(d, SynapseError) - - @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = defer.ensureDeferred( + self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.bob), "Frank Jr." - ) + ), + AuthError, ) - yield self.assertFailure(d, AuthError) - - @defer.inlineCallbacks def test_get_other_name(self): self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) - displayname = yield defer.ensureDeferred( - self.handler.get_displayname(self.alice) - ) + displayname = self.get_success(self.handler.get_displayname(self.alice)) self.assertEquals(displayname, "Alice") self.mock_federation.make_query.assert_called_with( @@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase): ignore_backoff=True, ) - @defer.inlineCallbacks def test_incoming_fed_query(self): - yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield defer.ensureDeferred( - self.store.set_profile_displayname("caroline", "Caroline") - ) + self.get_success(self.store.create_profile("caroline")) + self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) - response = yield defer.ensureDeferred( + response = self.get_success( self.query_handlers["profile"]( {"user_id": "@caroline:test", "field": "displayname"} ) @@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals({"displayname": "Caroline"}, response) - @defer.inlineCallbacks def test_get_my_avatar(self): - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) ) - avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) + avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) self.assertEquals("http://my.server/me.png", avatar_url) - @defer.inlineCallbacks def test_set_my_avatar(self): - yield defer.ensureDeferred( + self.get_success( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), @@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/pic.gif", ) # Set avatar again - yield defer.ensureDeferred( + self.get_success( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), @@ -230,56 +201,42 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) # Set avatar to an empty string - yield defer.ensureDeferred( + self.get_success( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), "", ) ) self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), ) - @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) ) self.assertEquals( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) # Set avatar a second time is forbidden - d = defer.ensureDeferred( + self.get_failure( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), "http://my.server/pic.gif", - ) + ), + SynapseError, ) - - yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index a8d6c0f617..029af2853e 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase): ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + @override_config( + { + "saml2_config": { + "attribute_requirements": [ + {"attribute": "userGroup", "value": "staff"}, + {"attribute": "department", "value": "sales"}, + ], + }, + } + ) + def test_attribute_requirements(self): + """The required attributes must be met from the SAML response.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # The response doesn't have the proper userGroup or department. + saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + auth_handler.complete_sso_login.assert_not_called() + + # The response doesn't have the proper department. + saml_response = FakeAuthnResponse( + {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]} + ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + auth_handler.complete_sso_login.assert_not_called() + + # Add the proper attributes and it should succeed. + saml_response = FakeAuthnResponse( + { + "uid": "test_user", + "username": "test_user", + "userGroup": ["staff", "admin"], + "department": ["sales"], + } + ) + request.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None, new_user=True + ) + def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader"]) + return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index c4e1e7ed85..22f452ec24 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -124,13 +124,18 @@ class EmailPusherTests(HomeserverTestCase): ) self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) - # The other user sends some messages + # The other user sends a single message. self.helper.send(room, body="Hi!", tok=self.others[0].token) - self.helper.send(room, body="There!", tok=self.others[0].token) # We should get emailed about that message self._check_for_mail() + # The other user sends multiple messages. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + self.helper.send(room, body="There!", tok=self.others[0].token) + + self._check_for_mail() + def test_invite_sends_email(self): # Create a room and invite the user to it room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) @@ -217,6 +222,45 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about those messages self._check_for_mail() + def test_empty_room(self): + """All users leaving a room shouldn't cause the pusher to break.""" + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id + ) + self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + + # Leave the room before the message is processed. + self.helper.leave(room, self.user_id, tok=self.access_token) + self.helper.leave(room, self.others[0].id, tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + + def test_empty_room_multiple_messages(self): + """All users leaving a room shouldn't cause the pusher to break.""" + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id + ) + self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + self.helper.send(room, body="There!", tok=self.others[0].token) + + # Leave the room before the message is processed. + self.helper.leave(room, self.user_id, tok=self.access_token) + self.helper.leave(room, self.others[0].id, tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + def test_encrypted_message(self): room = self.helper.create_room_as(self.user_id, tok=self.access_token) self.helper.invite( @@ -269,3 +313,6 @@ class EmailPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) + + # Reset the attempts. + self.email_attempts = [] diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 7c47aa7e0a..2a217b1ce0 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1445,6 +1445,90 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + def test_context_as_non_admin(self): + """ + Test that, without being admin, one cannot use the context admin API + """ + # Create a room. + user_id = self.register_user("test", "test") + user_tok = self.login("test", "test") + + self.register_user("test_2", "test") + user_tok_2 = self.login("test_2", "test") + + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + # Populate the room with events. + events = [] + for i in range(30): + events.append( + self.helper.send_event( + room_id, "com.example.test", content={"index": i}, tok=user_tok + ) + ) + + # Now attempt to find the context using the admin API without being admin. + midway = (len(events) - 1) // 2 + for tok in [user_tok, user_tok_2]: + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/context/%s" + % (room_id, events[midway]["event_id"]), + access_token=tok, + ) + self.assertEquals( + 403, int(channel.result["code"]), msg=channel.result["body"] + ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_context_as_admin(self): + """ + Test that, as admin, we can find the context of an event without having joined the room. + """ + + # Create a room. We're not part of it. + user_id = self.register_user("test", "test") + user_tok = self.login("test", "test") + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + # Populate the room with events. + events = [] + for i in range(30): + events.append( + self.helper.send_event( + room_id, "com.example.test", content={"index": i}, tok=user_tok + ) + ) + + # Now let's fetch the context for this room. + midway = (len(events) - 1) // 2 + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/context/%s" + % (room_id, events[midway]["event_id"]), + access_token=self.admin_user_tok, + ) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals( + channel.json_body["event"]["event_id"], events[midway]["event_id"] + ) + + for i, found_event in enumerate(channel.json_body["events_before"]): + for j, posted_event in enumerate(events): + if found_event["event_id"] == posted_event["event_id"]: + self.assertTrue(j < midway) + break + else: + self.fail("Event %s from events_before not found" % j) + + for i, found_event in enumerate(channel.json_body["events_after"]): + for j, posted_event in enumerate(events): + if found_event["event_id"] == posted_event["event_id"]: + self.assertTrue(j > midway) + break + else: + self.fail("Event %s from events_after not found" % j) + class MakeRoomAdminTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index bfcb786af8..49543d9acb 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,7 +15,7 @@ import time import urllib.parse -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union from urllib.parse import urlencode from mock import Mock @@ -493,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class + html = channel.result["body"].decode("utf-8") p = TestHtmlParser() - p.feed(channel.result["body"].decode("utf-8")) + p.feed(html) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) + # there should be a link for each href + returned_idps = [] # type: List[str] + for link in p.links: + path, query = link.split("?", 1) + self.assertEqual(path, "pick_idp") + params = urllib.parse.parse_qs(query) + self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) + returned_idps.append(params["idp"][0]) - self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) + self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index e59fa70baa..f3448c94dd 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,163 +14,11 @@ # limitations under the License. """Tests REST events for /profile paths.""" -import json - -from mock import Mock - -from twisted.internet import defer - -import synapse.types -from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, profile, room from tests import unittest -from ....utils import MockHttpResource, setup_test_homeserver - -myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/r0" - - -class MockHandlerProfileTestCase(unittest.TestCase): - """ Tests rest layer of profile management. - - Todo: move these into ProfileTestCase - """ - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.mock_handler = Mock( - spec=[ - "get_displayname", - "set_displayname", - "get_avatar_url", - "set_avatar_url", - "check_profile_query_allowed", - ] - ) - - self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( - Mock() - ) - - hs = yield setup_test_homeserver( - self.addCleanup, - "test", - federation_http_client=None, - resource_for_client=self.mock_resource, - federation=Mock(), - federation_client=Mock(), - profile_handler=self.mock_handler, - ) - - async def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) - - hs.get_auth().get_user_by_req = _get_user_by_req - - profile.register_servlets(hs, self.mock_resource) - - @defer.inlineCallbacks - def test_get_my_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Frank") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Frank"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}' - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") - - @defer.inlineCallbacks - def test_set_my_name_noauth(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = AuthError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@4567:test"), - b'{"displayname": "Frank Jr."}', - ) - - self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_other_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Bob") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Bob"}, response) - - @defer.inlineCallbacks - def test_set_other_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = SynapseError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@opaque:elsewhere"), - b'{"displayname":"bob"}', - ) - - self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_my_avatar(self): - mocked_get = self.mock_handler.get_avatar_url - mocked_get.return_value = defer.succeed("http://my.server/me.png") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/avatar_url" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"avatar_url": "http://my.server/me.png"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_avatar(self): - mocked_set = self.mock_handler.set_avatar_url - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/avatar_url" % (myid), - b'{"avatar_url": "http://my.server/pic.gif"}', - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") - class ProfileTestCase(unittest.HomeserverTestCase): @@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") def test_set_displayname(self): channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test"}), + content={"displayname": "test"}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 200, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "test") + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test" * 100}), + content={"displayname": "test" * 100}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 400, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "owner") - def get_displayname(self): - channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 38c51525a3..f6f3b9a356 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -18,8 +18,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.rest.client.v1 import room from synapse.types import UserID @@ -60,32 +58,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_datastore().insert_client_ip = _insert_client_ip - def get_room_members(room_id): - if room_id == self.room_id: - return defer.succeed([self.user]) - else: - return defer.succeed([]) - - @defer.inlineCallbacks - def fetch_room_distributions_into( - room_id, localusers=None, remotedomains=None, ignore_user=None - ): - members = yield get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - hs.get_room_member_handler().fetch_room_distributions_into = ( - fetch_room_distributions_into - ) - return hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py new file mode 100644 index 0000000000..7c22293d6d --- /dev/null +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest +from tests.server import FakeChannel + + +class UpgradeRoomTest(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + self.creator = self.register_user("creator", "pass") + self.creator_token = self.login(self.creator, "pass") + + self.other = self.register_user("user", "pass") + self.other_token = self.login(self.other, "pass") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) + self.helper.join(self.room_id, self.other, tok=self.other_token) + + def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: + # We never want a cached response. + self.reactor.advance(5 * 60 + 1) + + return self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, + # This will upgrade a room to the same version, but that's fine. + content={"new_version": DEFAULT_ROOM_VERSION}, + access_token=token or self.creator_token, + ) + + def test_upgrade(self): + """ + Upgrading a room should work fine. + """ + channel = self._upgrade_room() + self.assertEquals(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + def test_not_in_room(self): + """ + Upgrading a room should work fine. + """ + # THe user isn't in the room. + roomless = self.register_user("roomless", "pass") + roomless_token = self.login(roomless, "pass") + + channel = self._upgrade_room(roomless_token) + self.assertEquals(403, channel.code, channel.result) + + def test_power_levels(self): + """ + Another user can upgrade the room if their power level is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.creator_token, + ) + power_levels["users"][self.other] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_user_default(self): + """ + Another user can upgrade the room if the default power level for users is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.creator_token, + ) + power_levels["users_default"] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_tombstone(self): + """ + Another user can upgrade the room if they can send the tombstone event. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.creator_token, + ) + power_levels["events"]["m.room.tombstone"] = 0 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + power_levels = self.helper.get_state( + self.room_id, "m.room.power_levels", tok=self.creator_token, + ) + self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index a6c6985173..c279eb49e3 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -30,6 +30,8 @@ from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable +from synapse.rest import admin +from synapse.rest.client.v1 import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage @@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from tests import unittest from tests.server import FakeSite, make_request +from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): @@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers.getRawHeaders(b"X-Robots-Tag"), [b"noindex, nofollow, noarchive, noimageindex"], ) + + +class TestSpamChecker: + """A spam checker module that rejects all media that includes the bytes + `evil`. + """ + + def __init__(self, config, api): + self.config = config + self.api = api + + def parse_config(config): + return config + + async def check_event_for_spam(self, foo): + return False # allow all events + + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): + return True # allow all invites + + async def user_may_create_room(self, userid): + return True # allow all room creations + + async def user_may_create_room_alias(self, userid, room_alias): + return True # allow all room aliases + + async def user_may_publish_room(self, userid, room_id): + return True # allow publishing of all rooms + + async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: + buf = BytesIO() + await file_wrapper.write_chunks_to(buf.write) + + return b"evil" in buf.getvalue() + + +class SpamCheckerTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + + def default_config(self): + config = default_config("test") + + config.update( + { + "spam_checker": [ + { + "module": TestSpamChecker.__module__ + ".TestSpamChecker", + "config": {}, + } + ] + } + ) + + return config + + def test_upload_innocent(self): + """Attempt to upload some innocent data that should be allowed. + """ + + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + self.helper.upload_media( + self.upload_resource, image_data, tok=self.tok, expect_code=200 + ) + + def test_upload_ban(self): + """Attempt to upload some data that includes bytes "evil", which should + get rejected by the spam checker. + """ + + data = b"Some evil data" + + self.helper.upload_media( + self.upload_resource, data, tok=self.tok, expect_code=400 + ) diff --git a/tests/test_preview.py b/tests/test_preview.py index 0c6cbbd921..ea83299918 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -15,6 +15,7 @@ from synapse.rest.media.v1.preview_url_resource import ( decode_and_calc_og, + get_html_media_encoding, summarize_paragraphs, ) @@ -26,7 +27,7 @@ except ImportError: lxml = None -class PreviewTestCase(unittest.TestCase): +class SummarizeTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" @@ -144,12 +145,12 @@ class PreviewTestCase(unittest.TestCase): ) -class PreviewUrlTestCase(unittest.TestCase): +class CalcOgTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" def test_simple(self): - html = """ + html = b""" <html> <head><title>Foo</title></head> <body> @@ -163,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment(self): - html = """ + html = b""" <html> <head><title>Foo</title></head> <body> @@ -178,7 +179,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_comment2(self): - html = """ + html = b""" <html> <head><title>Foo</title></head> <body> @@ -202,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase): ) def test_script(self): - html = """ + html = b""" <html> <head><title>Foo</title></head> <body> @@ -217,7 +218,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_missing_title(self): - html = """ + html = b""" <html> <body> Some text. @@ -230,7 +231,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_h1_as_title(self): - html = """ + html = b""" <html> <meta property="og:description" content="Some text."/> <body> @@ -244,7 +245,7 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) def test_missing_title_and_broken_h1(self): - html = """ + html = b""" <html> <body> <h1><a href="foo"/></h1> @@ -258,13 +259,20 @@ class PreviewUrlTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) def test_empty(self): - html = "" + """Test a body with no data in it.""" + html = b"" + og = decode_and_calc_og(html, "http://example.com/test.html") + self.assertEqual(og, {}) + + def test_no_tree(self): + """A valid body with no tree in it.""" + html = b"\x00" og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEqual(og, {}) def test_invalid_encoding(self): """An invalid character encoding should be ignored and treated as UTF-8, if possible.""" - html = """ + html = b""" <html> <head><title>Foo</title></head> <body> @@ -290,3 +298,76 @@ class PreviewUrlTestCase(unittest.TestCase): """ og = decode_and_calc_og(html, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) + + +class MediaEncodingTestCase(unittest.TestCase): + def test_meta_charset(self): + """A character encoding is found via the meta tag.""" + encoding = get_html_media_encoding( + b""" + <html> + <head><meta charset="ascii"> + </head> + </html> + """, + "text/html", + ) + self.assertEqual(encoding, "ascii") + + # A less well-formed version. + encoding = get_html_media_encoding( + b""" + <html> + <head>< meta charset = ascii> + </head> + </html> + """, + "text/html", + ) + self.assertEqual(encoding, "ascii") + + def test_xml_encoding(self): + """A character encoding is found via the meta tag.""" + encoding = get_html_media_encoding( + b""" + <?xml version="1.0" encoding="ascii"?> + <html> + </html> + """, + "text/html", + ) + self.assertEqual(encoding, "ascii") + + def test_meta_xml_encoding(self): + """Meta tags take precedence over XML encoding.""" + encoding = get_html_media_encoding( + b""" + <?xml version="1.0" encoding="ascii"?> + <html> + <head><meta charset="UTF-16"> + </head> + </html> + """, + "text/html", + ) + self.assertEqual(encoding, "UTF-16") + + def test_content_type(self): + """A character encoding is found via the Content-Type header.""" + # Test a few variations of the header. + headers = ( + 'text/html; charset="ascii";', + "text/html;charset=ascii;", + 'text/html; charset="ascii"', + "text/html; charset=ascii", + 'text/html; charset="ascii;', + 'text/html; charset=ascii";', + ) + for header in headers: + encoding = get_html_media_encoding(b"", header) + self.assertEqual(encoding, "ascii") + + def test_fallback(self): + """A character encoding cannot be found in the body or header.""" + encoding = get_html_media_encoding(b"", "text/html") + self.assertEqual(encoding, "utf-8") |