diff options
author | Will Hunt <will@half-shot.uk> | 2020-09-28 11:11:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-28 11:11:57 +0100 |
commit | 6201ea56eeae22258d58db2f7381f96759f5e85e (patch) | |
tree | 37cb630a46fe96e416110de253ad50d256c0482e | |
parent | Add support for device messages, start support for device lists (diff) | |
parent | typo (diff) | |
download | synapse-6201ea56eeae22258d58db2f7381f96759f5e85e.tar.xz |
Merge branch 'develop' into hs/super-wip-edus-down-sync
86 files changed, 1831 insertions, 496 deletions
diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db index f20567ba73..361369a581 100644 --- a/.buildkite/test_db.db +++ b/.buildkite/test_db.db Binary files differdiff --git a/CHANGES.md b/CHANGES.md index 84976ab2bd..5de819ea1e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,30 @@ +Synapse 1.20.1 (2020-09-24) +=========================== + +Bugfixes +-------- + +- Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386)) +- Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) + + +Synapse 1.20.0 (2020-09-22) +=========================== + +No significant changes since v1.20.0rc5. + +Removal warning +--------------- + +Historically, the [Synapse Admin +API](https://github.com/matrix-org/synapse/tree/master/docs) has been +accessible under the `/_matrix/client/api/v1/admin`, +`/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and +`/_synapse/admin` prefixes. In a future release, we will be dropping support +for accessing Synapse's Admin API using the `/_matrix/client/*` prefixes. This +makes it easier for homeserver admins to lock down external access to the Admin +API endpoints. + Synapse 1.20.0rc5 (2020-09-18) ============================== diff --git a/changelog.d/8217.feature b/changelog.d/8217.feature new file mode 100644 index 0000000000..899cbf14ef --- /dev/null +++ b/changelog.d/8217.feature @@ -0,0 +1 @@ +Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8330.misc b/changelog.d/8330.misc index c51370f215..fbfdd52473 100644 --- a/changelog.d/8330.misc +++ b/changelog.d/8330.misc @@ -1 +1 @@ -Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. \ No newline at end of file +Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. diff --git a/changelog.d/8345.feature b/changelog.d/8345.feature new file mode 100644 index 0000000000..4ee5b6a56e --- /dev/null +++ b/changelog.d/8345.feature @@ -0,0 +1 @@ +Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang. diff --git a/changelog.d/8353.bugfix b/changelog.d/8353.bugfix new file mode 100644 index 0000000000..45fc0adb8d --- /dev/null +++ b/changelog.d/8353.bugfix @@ -0,0 +1 @@ +Don't send push notifications to expired user accounts. diff --git a/changelog.d/8362.bugfix b/changelog.d/8362.bugfix new file mode 100644 index 0000000000..4e50067c87 --- /dev/null +++ b/changelog.d/8362.bugfix @@ -0,0 +1 @@ +Fixed a regression in v1.19.0 with reactivating users through the admin API. diff --git a/changelog.d/8364.bugfix b/changelog.d/8364.bugfix new file mode 100644 index 0000000000..7b82cbc388 --- /dev/null +++ b/changelog.d/8364.bugfix @@ -0,0 +1,2 @@ +Fix a bug where during device registration the length of the device name wasn't +limited. diff --git a/changelog.d/8370.misc b/changelog.d/8370.misc new file mode 100644 index 0000000000..1aaac1e0bf --- /dev/null +++ b/changelog.d/8370.misc @@ -0,0 +1 @@ +Factor out a `_send_dummy_event_for_room` method. diff --git a/changelog.d/8371.misc b/changelog.d/8371.misc new file mode 100644 index 0000000000..6a54a9496a --- /dev/null +++ b/changelog.d/8371.misc @@ -0,0 +1 @@ +Improve logging of state resolution. diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc new file mode 100644 index 0000000000..a56e36de4b --- /dev/null +++ b/changelog.d/8372.misc @@ -0,0 +1 @@ +Add type annotations to `SimpleHttpClient`. diff --git a/changelog.d/8373.bugfix b/changelog.d/8373.bugfix new file mode 100644 index 0000000000..e9d66a2088 --- /dev/null +++ b/changelog.d/8373.bugfix @@ -0,0 +1 @@ +Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2. \ No newline at end of file diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix new file mode 100644 index 0000000000..155bc3404f --- /dev/null +++ b/changelog.d/8374.bugfix @@ -0,0 +1 @@ +Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. diff --git a/changelog.d/8375.doc b/changelog.d/8375.doc new file mode 100644 index 0000000000..d291fb92fa --- /dev/null +++ b/changelog.d/8375.doc @@ -0,0 +1 @@ +Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau). diff --git a/changelog.d/8377.misc b/changelog.d/8377.misc new file mode 100644 index 0000000000..fbfdd52473 --- /dev/null +++ b/changelog.d/8377.misc @@ -0,0 +1 @@ +Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. diff --git a/changelog.d/8383.misc b/changelog.d/8383.misc new file mode 100644 index 0000000000..cb8318bf57 --- /dev/null +++ b/changelog.d/8383.misc @@ -0,0 +1 @@ +Refactor ID generators to use `async with` syntax. diff --git a/changelog.d/8385.bugfix b/changelog.d/8385.bugfix new file mode 100644 index 0000000000..c42502a8e0 --- /dev/null +++ b/changelog.d/8385.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite. diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix new file mode 100644 index 0000000000..24983a1e95 --- /dev/null +++ b/changelog.d/8386.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. diff --git a/changelog.d/8387.feature b/changelog.d/8387.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8387.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/changelog.d/8388.misc b/changelog.d/8388.misc new file mode 100644 index 0000000000..aaaef88b66 --- /dev/null +++ b/changelog.d/8388.misc @@ -0,0 +1 @@ +Add `EventStreamPosition` type. diff --git a/changelog.d/8396.feature b/changelog.d/8396.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8396.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/changelog.d/8398.bugfix b/changelog.d/8398.bugfix new file mode 100644 index 0000000000..e432aeebf1 --- /dev/null +++ b/changelog.d/8398.bugfix @@ -0,0 +1 @@ +Fix "Re-starting finished log context" warning when receiving an event we already had over federation. diff --git a/changelog.d/8405.feature b/changelog.d/8405.feature new file mode 100644 index 0000000000..f3c4a74bc7 --- /dev/null +++ b/changelog.d/8405.feature @@ -0,0 +1 @@ +Consolidate the SSO error template across all configuration. diff --git a/debian/changelog b/debian/changelog index dbf01d6b1e..264ef9ce7c 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,8 +1,18 @@ -matrix-synapse-py3 (1.20.0ubuntu1) UNRELEASED; urgency=medium +matrix-synapse-py3 (1.20.1) stable; urgency=medium + * New synapse release 1.20.1. + + -- Synapse Packaging team <packages@matrix.org> Thu, 24 Sep 2020 16:25:22 +0100 + +matrix-synapse-py3 (1.20.0) stable; urgency=medium + + [ Synapse Packaging team ] + * New synapse release 1.20.0. + + [ Dexter Chua ] * Use Type=notify in systemd service - -- Dexter Chua <dec41@srcf.net> Wed, 26 Aug 2020 12:41:36 +0000 + -- Synapse Packaging team <packages@matrix.org> Tue, 22 Sep 2020 15:19:32 +0100 matrix-synapse-py3 (1.19.3) stable; urgency=medium diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst new file mode 100644 index 0000000000..461be01230 --- /dev/null +++ b/docs/admin_api/event_reports.rst @@ -0,0 +1,129 @@ +Show reported events +==================== + +This API returns information about reported events. + +The api is:: + + GET /_synapse/admin/v1/event_reports?from=0&limit=10 + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +It returns a JSON body like the following: + +.. code:: jsonc + + { + "event_reports": [ + { + "content": { + "reason": "foo", + "score": -100 + }, + "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", + "event_json": { + "auth_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", + "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" + ], + "content": { + "body": "matrix.org: This Week in Matrix", + "format": "org.matrix.custom.html", + "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>", + "msgtype": "m.notice" + }, + "depth": 546, + "hashes": { + "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" + }, + "origin": "matrix.org", + "origin_server_ts": 1592291711430, + "prev_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" + ], + "prev_state": [], + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "signatures": { + "matrix.org": { + "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" + } + }, + "type": "m.room.message", + "unsigned": { + "age_ts": 1592291711430, + } + }, + "id": 2, + "reason": "foo", + "received_ts": 1570897107409, + "room_alias": "#alias1:matrix.org", + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "user_id": "@foo:matrix.org" + }, + { + "content": { + "reason": "bar", + "score": -100 + }, + "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I", + "event_json": { + // hidden items + // see above + }, + "id": 3, + "reason": "bar", + "received_ts": 1598889612059, + "room_alias": "#alias2:matrix.org", + "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org", + "sender": "@foobar:matrix.org", + "user_id": "@bar:matrix.org" + } + ], + "next_token": 2, + "total": 4 + } + +To paginate, check for ``next_token`` and if present, call the endpoint again +with ``from`` set to the value of ``next_token``. This will return a new page. + +If the endpoint does not return a ``next_token`` then there are no more +reports to paginate through. + +**URL parameters:** + +- ``limit``: integer - Is optional but is used for pagination, + denoting the maximum number of items to return in this call. Defaults to ``100``. +- ``from``: integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value and + not explicitly set to anything other than the return value of ``next_token`` from a previous call. + Defaults to ``0``. +- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the + oldest first (``f``). Defaults to ``b``. +- ``user_id``: string - Is optional and filters to only return users with user IDs that contain this value. + This is the user who reported the event and wrote the reason. +- ``room_id``: string - Is optional and filters to only return rooms with room IDs that contain this value. + +**Response** + +The following fields are returned in the JSON response body: + +- ``id``: integer - ID of event report. +- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent. +- ``room_id``: string - The ID of the room in which the event being reported is located. +- ``event_id``: string - The ID of the reported event. +- ``user_id``: string - This is the user who reported the event and wrote the reason. +- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. +- ``content``: object - Content of reported event. + + - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. + - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". + +- ``sender``: string - This is the ID of the user who sent the original message/event that was reported. +- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set. +- ``event_json``: object - Details of the original event that was reported. +- ``next_token``: integer - Indication for pagination. See above. +- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``). + diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index edd109fa7b..46d8f35771 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -121,6 +121,14 @@ example.com:8448 { **NOTE**: ensure the `nocanon` options are included. +**NOTE 2**: It appears that Synapse is currently incompatible with the ModSecurity module for Apache (`mod_security2`). If you need it enabled for other services on your web server, you can disable it for Synapse's two VirtualHosts by including the following lines before each of the two `</VirtualHost>` above: + +``` +<IfModule security2_module> + SecRuleEngine off +</IfModule> +``` + ### HAProxy ``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fb04ff283d..845f537795 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1689,6 +1689,11 @@ oidc_config: # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a34bdf1830..684a518b8e 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], + "users": ["shadow_banned"], } @@ -627,6 +628,7 @@ class Porter(object): self.progress.set_state("Setting up sequence generators") await self._setup_state_group_id_seq() await self._setup_user_id_seq() + await self._setup_events_stream_seqs() self.progress.done() except Exception as e: @@ -803,6 +805,29 @@ class Porter(object): return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) + def _setup_events_stream_seqs(self): + def r(txn): + txn.execute("SELECT MAX(stream_ordering) FROM events") + curr_id = txn.fetchone()[0] + if curr_id: + next_id = curr_id + 1 + txn.execute( + "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,) + ) + + txn.execute("SELECT -MIN(stream_ordering) FROM events") + curr_id = txn.fetchone()[0] + if curr_id: + next_id = curr_id + 1 + txn.execute( + "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", + (next_id,), + ) + + return self.postgres_store.db_pool.runInteraction( + "_setup_events_stream_seqs", r + ) + ############################################## # The following is simply UI stuff diff --git a/setup.py b/setup.py index 54ddec8f9f..926b1bc86f 100755 --- a/setup.py +++ b/setup.py @@ -94,6 +94,22 @@ ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"] # Make `pip install matrix-synapse[all]` install all the optional dependencies. CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS) +# Developer dependencies should not get included in "all". +# +# We pin black so that our tests don't start failing on new releases. +CONDITIONAL_REQUIREMENTS["lint"] = [ + "isort==5.0.3", + "black==19.10b0", + "flake8-comprehensions", + "flake8", +] + +# Dependencies which are exclusively required by unit test code. This is +# NOT a list of all modules that are necessary to run the unit tests. +# Tests assume that all optional dependencies are installed. +# +# parameterized_class decorator was introduced in parameterized 0.7.0 +CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"] setup( name="matrix-synapse", diff --git a/synapse/__init__.py b/synapse/__init__.py index a95753dcc7..e40b582bd5 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.20.0rc5" +__version__ = "1.20.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 75388643ee..1071a0576e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -218,11 +218,7 @@ class Auth: # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = await self.store.get_expiration_ts_for_user(user_id) - if ( - expiration_ts is not None - and self.clock.time_msec() >= expiration_ts - ): + if await self.store.is_account_expired(user_id, self.clock.time_msec()): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 9523c3a5d9..2c1dae5984 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - info = await self.get_json(uri, {}) + info = await self.get_json(uri) if not _is_valid_3pe_metadata(info): logger.warning( diff --git a/synapse/config/_base.py b/synapse/config/_base.py index bb9bf8598d..05a66841c3 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -194,7 +194,10 @@ class Config: return file_stream.read() def read_templates( - self, filenames: List[str], custom_template_directory: Optional[str] = None, + self, + filenames: List[str], + custom_template_directory: Optional[str] = None, + autoescape: bool = False, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -210,6 +213,9 @@ class Config: custom_template_directory: A directory to try to look for the templates before using the default Synapse template directory instead. + autoescape: Whether to autoescape variables before inserting them into the + template. + Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. @@ -233,7 +239,7 @@ class Config: search_directories.insert(0, custom_template_directory) loader = jinja2.FileSystemLoader(search_directories) - env = jinja2.Environment(loader=loader, autoescape=True) + env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters env.filters.update( diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e0939bce84..70fc8a2f62 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -56,6 +56,7 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -158,6 +159,11 @@ class OIDCConfig(Config): # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2b03f5ac76..79668a402e 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -45,7 +45,11 @@ _TLS_VERSION_MAP = { class ServerContextFactory(ContextFactory): """Factory for PyOpenSSL SSL contexts that are used to handle incoming - connections.""" + connections. + + TODO: replace this with an implementation of IOpenSSLServerConnectionCreator, + per https://github.com/matrix-org/synapse/issues/1691 + """ def __init__(self, config): # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 42e4087a92..c04ad77cf9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -42,7 +42,6 @@ from synapse.api.errors import ( ) from synapse.logging.context import ( PreserveLoggingContext, - current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -233,8 +232,6 @@ class Keyring: """ try: - ctx = current_context() - # map from server name to a set of outstanding request ids server_to_request_ids = {} @@ -265,12 +262,8 @@ class Keyring: # if there are no more requests for this server, we can drop the lock. if not server_requests: - with PreserveLoggingContext(ctx): - logger.debug("Releasing key lookup lock on %s", server_name) - - # ... but not immediately, as that can cause stack explosions if - # we get a long queue of lookups. - self.clock.call_later(0, drop_server_lock, server_name) + logger.debug("Releasing key lookup lock on %s", server_name) + drop_server_lock(server_name) return res @@ -335,20 +328,32 @@ class Keyring: ) # look for any requests which weren't satisfied - with PreserveLoggingContext(): - for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) + while remaining_requests: + verify_request = remaining_requests.pop() + rq_str = ( + "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, ) + ) + + # If we run the errback immediately, it may cancel our + # loggingcontext while we are still in it, so instead we + # schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we + # has a massive queue of lookups waiting for this server). + self.clock.call_later( + 0, + verify_request.key_ready.errback, + SynapseError( + 401, + "Failed to find any key to satisfy %s" % (rq_str,), + Codes.UNAUTHORIZED, + ), + ) except Exception as err: # we don't really expect to get here, because any errors should already # have been caught and logged. But if we do, let's log the error and make @@ -410,10 +415,23 @@ class Keyring: # key was not valid at this point continue - with PreserveLoggingContext(): - verify_request.key_ready.callback( - (server_name, key_id, fetch_key_result.verify_key) - ) + # we have a valid key for this request. If we run the callback + # immediately, it may cancel our loggingcontext while we are still in + # it, so instead we schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we had + # a massive queue of lookups waiting for this server). + logger.debug( + "Found key %s:%s for %s", + server_name, + key_id, + verify_request.request_name, + ) + self.clock.call_later( + 0, + verify_request.key_ready.callback, + (server_name, key_id, fetch_key_result.verify_key), + ) completed.append(verify_request) break diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 55a9787439..4149520d6c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, @@ -265,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) + def _check_device_name_length(self, name: str): + """ + Checks whether a device name is longer than the maximum allowed length. + + Args: + name: The name of the device. + + Raises: + SynapseError: if the device name is too long. + """ + if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + errcode=Codes.TOO_LARGE, + ) + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): @@ -282,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler): Returns: str: device id (generated if none was supplied) """ + + self._check_device_name_length(initial_device_display_name) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -397,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler): # Reject a new displayname which is too long. new_display_name = content.get("display_name") - if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: - raise SynapseError( - 400, - "Device display name is too long (max %i)" - % (MAX_DEVICE_DISPLAY_NAME_LEN,), - ) + + self._check_device_name_length(new_display_name) try: await self.store.update_device( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ea9264e751..9f773aefa7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( JsonDict, MutableStateMap, + PersistedEventPosition, + RoomStreamToken, StateMap, UserID, get_domain_from_id, @@ -2956,7 +2958,7 @@ class FederationHandler(BaseHandler): ) return result["max_stream_id"] else: - max_stream_id = await self.storage.persistence.persist_events( + max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -2967,12 +2969,12 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - await self._notify_persisted_event(event, max_stream_id) + await self._notify_persisted_event(event, max_stream_token) - return max_stream_id + return max_stream_token.stream async def _notify_persisted_event( - self, event: EventBase, max_stream_id: int + self, event: EventBase, max_stream_token: RoomStreamToken ) -> None: """Checks to see if notifier/pushers should be notified about the event or not. @@ -2998,9 +3000,11 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return - event_stream_id = event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( + self._instance_name, event.internal_metadata.stream_ordering + ) self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) async def _clean_room_for_join(self, room_id: str) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a8fe5cf4e2..ee271e85e5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1138,7 +1138,7 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = await self.storage.persistence.persist_event( + event_pos, max_stream_token = await self.storage.persistence.persist_event( event, context=context ) @@ -1149,7 +1149,7 @@ class EventCreationHandler: def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -1161,7 +1161,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - return event_stream_id + return event_pos.stream async def _bump_active_time(self, user: UserID) -> None: try: @@ -1182,54 +1182,7 @@ class EventCreationHandler: ) for room_id in room_ids: - # For each room we need to find a joined member we can use to send - # the dummy event with. - - latest_event_ids = await self.store.get_prev_events_for_room(room_id) - - members = await self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids - ) - dummy_event_sent = False - for user_id in members: - if not self.hs.is_mine_id(user_id): - continue - requester = create_requester(user_id) - try: - event, context = await self.create_event( - requester, - { - "type": "org.matrix.dummy_event", - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - prev_event_ids=latest_event_ids, - ) - - event.internal_metadata.proactively_send = False - - # Since this is a dummy-event it is OK if it is sent by a - # shadow-banned user. - await self.send_nonmember_event( - requester, - event, - context, - ratelimit=False, - ignore_shadow_ban=True, - ) - dummy_event_sent = True - break - except ConsentNotGivenError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of consent. Will try another user" % (room_id, user_id) - ) - except AuthError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of power. Will try another user" % (room_id, user_id) - ) + dummy_event_sent = await self._send_dummy_event_for_room(room_id) if not dummy_event_sent: # Did not find a valid user in the room, so remove from future attempts @@ -1242,6 +1195,59 @@ class EventCreationHandler: now = self.clock.time_msec() self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now + async def _send_dummy_event_for_room(self, room_id: str) -> bool: + """Attempt to send a dummy event for the given room. + + Args: + room_id: room to try to send an event from + + Returns: + True if a dummy event was successfully sent. False if no user was able + to send an event. + """ + + # For each room we need to find a joined member we can use to send + # the dummy event with. + latest_event_ids = await self.store.get_prev_events_for_room(room_id) + members = await self.state.get_current_users_in_room( + room_id, latest_event_ids=latest_event_ids + ) + for user_id in members: + if not self.hs.is_mine_id(user_id): + continue + requester = create_requester(user_id) + try: + event, context = await self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + + event.internal_metadata.proactively_send = False + + # Since this is a dummy-event it is OK if it is sent by a + # shadow-banned user. + await self.send_nonmember_event( + requester, event, context, ratelimit=False, ignore_shadow_ban=True, + ) + return True + except ConsentNotGivenError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of consent. Will try another user" % (room_id, user_id) + ) + except AuthError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of power. Will try another user" % (room_id, user_id) + ) + return False + def _expire_rooms_to_exclude_from_dummy_event_insertion(self): expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY to_expire = set() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4230dbaf99..0e06e4408d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -114,6 +114,7 @@ class OidcHandler: hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -849,7 +850,8 @@ class OidcHandler: If we don't find the user that way, we should register the user, mapping the localpart and the display name from the UserInfo. - If a user already exists with the mxid we've mapped, raise an exception. + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. Args: userinfo: an object representing the user @@ -905,21 +907,31 @@ class OidcHandler: localpart = map_username_to_mxid_localpart(attributes["localpart"]) - user_id = UserID(localpart, self._hostname) - if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + user_id = UserID(localpart, self._hostname).to_string() + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + if self._allow_existing_users: + if len(users) == 1: + registered_user_id = next(iter(users)) + elif user_id in users: + registered_user_id = user_id + else: + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, list(users.keys()) + ) + ) + else: + # This mxid is taken + raise MappingException("mxid '{}' is already taken".format(user_id)) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9b3a4f638b..e948efef2e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -967,7 +967,7 @@ class SyncHandler: raise NotImplementedError() else: joined_room_ids = await self.get_rooms_for_user_at( - user_id, now_token.room_stream_id + user_id, now_token.room_key ) sync_result_builder = SyncResultBuilder( sync_config, @@ -1916,7 +1916,7 @@ class SyncHandler: raise Exception("Unrecognized rtype: %r", room_builder.rtype) async def get_rooms_for_user_at( - self, user_id: str, stream_ordering: int + self, user_id: str, room_key: RoomStreamToken ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. @@ -1942,15 +1942,15 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. - for room_id, membership_stream_ordering in joined_rooms: - if membership_stream_ordering <= stream_ordering: + for room_id, event_pos in joined_rooms: + if not event_pos.persisted_after(room_key): joined_room_ids.add(room_id) continue logger.info("User joined room after current token: %s", room_id) extrems = await self.store.get_forward_extremeties_for_room( - room_id, stream_ordering + room_id, event_pos.stream ) users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: diff --git a/synapse/http/client.py b/synapse/http/client.py index 13fcab3378..4694adc400 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -17,6 +17,18 @@ import logging import urllib from io import BytesIO +from typing import ( + Any, + BinaryIO, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import treq from canonicaljson import encode_canonical_json @@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import ( @@ -57,6 +70,19 @@ incoming_responses_counter = Counter( "synapse_http_client_responses", "", ["method", "code"] ) +# the type of the headers list, to be passed to the t.w.h.Headers. +# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so +# we simplify. +RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] + +# the value actually has to be a List, but List is invariant so we can't specify that +# the entries can either be Lists or bytes. +RawHeaderValue = Sequence[Union[str, bytes]] + +# the type of the query params, to be passed into `urlencode` +QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] +QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] + def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): """ @@ -285,13 +311,26 @@ class SimpleHttpClient: ip_blacklist=self._ip_blacklist, ) - async def request(self, method, uri, data=None, headers=None): + async def request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: """ Args: - method (str): HTTP method to use. - uri (str): URI to query. - data (bytes): Data to send in the request body, if applicable. - headers (t.w.http_headers.Headers): Request headers. + method: HTTP method to use. + uri: URI to query. + data: Data to send in the request body, if applicable. + headers: Request headers. + + Returns: + Response object, once the headers have been read. + + Raises: + RequestTimedOutError if the request times out before the headers are read + """ # A small wrapper around self.agent.request() so we can easily attach # counters to it @@ -324,6 +363,8 @@ class SimpleHttpClient: headers=headers, **self._extra_treq_args ) + # we use our own timeout mechanism rather than treq's as a workaround + # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( request_deferred, 60, @@ -353,18 +394,26 @@ class SimpleHttpClient: set_tag("error_reason", e.args[0]) raise - async def post_urlencoded_get_json(self, uri, args={}, headers=None): + async def post_urlencoded_get_json( + self, + uri: str, + args: Mapping[str, Union[str, List[str]]] = {}, + headers: Optional[RawHeaders] = None, + ) -> Any: """ Args: - uri (str): - args (dict[str, str|List[str]]): query params - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: uri to query + args: parameters to be url-encoded in the body + headers: a map from header name to a list of values for that header Returns: - object: parsed json + parsed json Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException: On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -398,19 +447,24 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def post_json_get_json(self, uri, post_json, headers=None): + async def post_json_get_json( + self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None + ) -> Any: """ Args: - uri (str): - post_json (object): - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: URI to query. + post_json: request body, to be encoded as json + headers: a map from header name to a list of values for that header Returns: - object: parsed json + parsed json Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException: On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -440,21 +494,22 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def get_json(self, uri, args={}, headers=None): - """ Gets some json from the given URI. + async def get_json( + self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None, + ) -> Any: + """Gets some json from the given URI. Args: - uri (str): The URI to request, not including query parameters - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + args: A dictionary used to create query string + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -466,22 +521,27 @@ class SimpleHttpClient: body = await self.get_raw(uri, args, headers=headers) return json_decoder.decode(body.decode("utf-8")) - async def put_json(self, uri, json_body, args={}, headers=None): - """ Puts some json to the given URI. + async def put_json( + self, + uri: str, + json_body: Any, + args: QueryParams = {}, + headers: RawHeaders = None, + ) -> Any: + """Puts some json to the given URI. Args: - uri (str): The URI to request, not including query parameters - json_body (dict): The JSON to put in the HTTP body, - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + json_body: The JSON to put in the HTTP body, + args: A dictionary used to create query strings + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -513,21 +573,23 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def get_raw(self, uri, args={}, headers=None): - """ Gets raw text from the given URI. + async def get_raw( + self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None + ) -> bytes: + """Gets raw text from the given URI. Args: - uri (str): The URI to request, not including query parameters - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + args: A dictionary used to create query strings + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the + Succeeds when we get a 2xx HTTP response, with the HTTP body as bytes. Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException on a non-2xx HTTP response. """ if len(args): @@ -552,16 +614,29 @@ class SimpleHttpClient: # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. - async def get_file(self, url, output_stream, max_size=None, headers=None): + async def get_file( + self, + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: """GETs a file from a given URL Args: - url (str): The URL to GET - output_stream (file): File to write the response body to. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + url: The URL to GET + output_stream: File to write the response body to. + headers: A map from header name to a list of values for that header Returns: - A (int,dict,string,int) tuple of the file length, dict of the response + A tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. + + Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + + SynapseError: if the response is not a 2xx, the remote file is too large, or + another exception happens during the download. """ actual_headers = {b"User-Agent": [self.user_agent]} diff --git a/synapse/notifier.py b/synapse/notifier.py index 37546e6c21..b6b231c15d 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -42,7 +42,13 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig -from synapse.types import Collection, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Collection, + PersistedEventPosition, + RoomStreamToken, + StreamToken, + UserID, +) from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -187,7 +193,7 @@ class Notifier: self.store = hs.get_datastore() self.pending_new_room_events = ( [] - ) # type: List[Tuple[int, EventBase, Collection[UserID]]] + ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -246,8 +252,8 @@ class Notifier: def on_new_room_event( self, event: EventBase, - room_stream_id: int, - max_room_stream_id: int, + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): """ Used by handlers to inform the notifier something has happened @@ -261,16 +267,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((room_stream_id, event, extra_users)) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append((event_pos, event, extra_users)) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id: int): + def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id: The highest stream_id below which all + max_room_stream_token: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events @@ -279,11 +285,9 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for room_stream_id, event, extra_users in pending: - if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append( - (room_stream_id, event, extra_users) - ) + for event_pos, event, extra_users in pending: + if event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append((event_pos, event, extra_users)) else: if ( event.type == EventTypes.Member @@ -296,33 +300,32 @@ class Notifier: if users or rooms: self.on_new_event( - "room_key", - RoomStreamToken(None, max_room_stream_id), - users=users, - rooms=rooms, + "room_key", max_room_stream_token, users=users, rooms=rooms, ) - self._on_updated_room_token(max_room_stream_id) + self._on_updated_room_token(max_room_stream_token) - def _on_updated_room_token(self, max_room_stream_id: int): + def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken): """Poke services that might care that the room position has been updated. """ # poke any interested application service. run_as_background_process( - "_notify_app_services", self._notify_app_services, max_room_stream_id + "_notify_app_services", self._notify_app_services, max_room_stream_token ) run_as_background_process( - "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id + "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token ) if self.federation_sender: - self.federation_sender.notify_new_events(max_room_stream_id) + self.federation_sender.notify_new_events(max_room_stream_token.stream) - async def _notify_app_services(self, max_room_stream_id: int): + async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services(max_room_stream_id) + await self.appservice_handler.notify_interested_services( + max_room_stream_token.stream + ) except Exception: logger.exception("Error notifying application services of event") @@ -332,9 +335,9 @@ class Notifier: except Exception: logger.exception("Error notifying application services of event") - async def _notify_pusher_pool(self, max_room_stream_id: int): + async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_id) + await self._pusher_pool.on_new_notifications(max_room_stream_token.stream) except Exception: logger.exception("Error pusher pool of event") diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index cc839ffce4..76150e117b 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -60,6 +60,8 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + self._account_validity = hs.config.account_validity + # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() @@ -202,6 +204,14 @@ class PusherPool: ) for u in users_affected: + # Don't push if the user account has expired + if self._account_validity.enabled: + expired = await self.store.is_account_expired( + u, self.clock.time_msec() + ) + if expired: + continue + if u in self.pushers: for p in self.pushers[u].values(): p.on_new_notifications(max_stream_id) @@ -222,6 +232,14 @@ class PusherPool: ) for u in users_affected: + # Don't push if the user account has expired + if self._account_validity.enabled: + expired = await self.store.is_account_expired( + u, self.clock.time_msec() + ) + if expired: + continue + if u in self.pushers: for p in self.pushers[u].values(): p.on_new_receipts(min_stream_id, max_stream_id) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 67f019fd22..288631477e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -37,6 +37,9 @@ logger = logging.getLogger(__name__) # installed when that optional dependency requirement is specified. It is passed # to setup() as extras_require in setup.py # +# Note that these both represent runtime dependencies (and the versions +# installed are checked at runtime). +# # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ @@ -92,20 +95,12 @@ CONDITIONAL_REQUIREMENTS = { "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], - # Dependencies which are exclusively required by unit test code. This is - # NOT a list of all modules that are necessary to run the unit tests. - # Tests assume that all optional dependencies are installed. - # - # parameterized_class decorator was introduced in parameterized 0.7.0 - "test": ["mock>=2.0", "parameterized>=0.7.0"], "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], # hiredis is not a *strict* dependency, but it makes things much faster. # (if it is not installed, we fall back to slow code.) "redis": ["txredisapi>=1.4.7", "hiredis"], - # We pin black so that our tests don't start failing on new releases. - "lint": ["isort==5.0.3", "black==19.10b0", "flake8-comprehensions", "flake8"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] @@ -113,7 +108,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): # Exclude systemd as it's a system-based requirement. # Exclude lint as it's a dev-based requirement. - if name not in ["systemd", "lint"]: + if name not in ["systemd"]: ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d25fa49e1a..d0089fe06c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e82b9e386f..55af3d41ea 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import UserID +from synapse.types import PersistedEventPosition, RoomStreamToken, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -151,8 +151,14 @@ class ReplicationDataHandler: extra_users = () # type: Tuple[UserID, ...] if event.type == EventTypes.Member: extra_users = (UserID.from_string(event.state_key),) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event(event, token, max_token, extra_users) + + max_token = RoomStreamToken( + None, self.store.get_room_max_stream_ordering() + ) + event_pos = PersistedEventPosition(instance_name, token) + self.notifier.on_new_room_event( + event, event_pos, max_token, extra_users + ) # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index af8459719a..944bc9c9ca 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -12,7 +12,7 @@ <p> There was an error during authentication: </p> - <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div> + <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div> <p> If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 4a75c06480..5c5f00b213 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -31,6 +31,7 @@ from synapse.rest.admin.devices import ( DeviceRestServlet, DevicesRestServlet, ) +from synapse.rest.admin.event_reports import EventReportsRestServlet from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -216,6 +217,7 @@ def register_servlets(hs, http_server): DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + EventReportsRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py new file mode 100644 index 0000000000..5b8d0594cd --- /dev/null +++ b/synapse/rest/admin/event_reports.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin + +logger = logging.getLogger(__name__) + + +class EventReportsRestServlet(RestServlet): + """ + List all reported events that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results. + The parameter `user_id` can be used to filter by user id. + The parameter `room_id` can be used to filter by room id. + Returns: + A list of reported events and an integer representing the total number of + reported events that exist given this query + """ + + PATTERNS = admin_patterns("/event_reports$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + await assert_requester_is_admin(self.auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_string(request, "dir", default="b") + user_id = parse_string(request, "user_id") + room_id = parse_string(request, "room_id") + + if start < 0: + raise SynapseError( + 400, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + event_reports, total = await self.store.get_event_reports_paginate( + start, limit, direction, user_id, room_id + ) + ret = {"event_reports": event_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(event_reports) + + return 200, ret diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 987765e877..dce6c4d168 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url, user): + async def _download_url(self, url: str, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url + url_to_download = url # type: Optional[str] oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: we should calculate a proper expiration based on the # Cache-Control and Expire headers. But for now, assume 1 hour. expires = ONE_HOUR - etag = headers["ETag"][0] if "ETag" in headers else None + etag = ( + headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + ) else: - html_bytes = oembed_result.html.encode("utf-8") # type: ignore + # we can only get here if we did an oembed request and have an oembed_result.html + assert oembed_result.html is not None + assert oembed_url is not None + + html_bytes = oembed_result.html.encode("utf-8") with self.media_storage.store_into_file(file_info) as (f, fname, finish): f.write(html_bytes) await finish() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 56d6afb863..5a5ea39e01 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -25,7 +25,6 @@ from typing import ( Sequence, Set, Union, - cast, overload, ) @@ -42,7 +41,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, MutableStateMap, StateMap +from synapse.types import Collection, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -472,10 +471,9 @@ class StateResolutionHandler: def __init__(self, hs): self.clock = hs.get_clock() - # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = None self.resolve_linearizer = Linearizer(name="state_resolve_lock") + # dict of set of event_ids -> _StateCacheEntry. self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, @@ -519,57 +517,28 @@ class StateResolutionHandler: Returns: The resolved state """ - logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) - group_names = frozenset(state_groups_ids.keys()) with (await self.resolve_linearizer.queue(group_names)): - if self._state_cache is not None: - cache = self._state_cache.get(group_names, None) - if cache: - return cache + cache = self._state_cache.get(group_names, None) + if cache: + return cache logger.info( - "Resolving state for %s with %d groups", room_id, len(state_groups_ids) + "Resolving state for %s with groups %s", room_id, list(group_names), ) state_groups_histogram.observe(len(state_groups_ids)) - # start by assuming we won't have any conflicted state, and build up the new - # state map by iterating through the state groups. If we discover a conflict, - # we give up and instead use `resolve_events_with_store`. - # - # XXX: is this actually worthwhile, or should we just let - # resolve_events_with_store do it? - new_state = {} # type: MutableStateMap[str] - conflicted_state = False - for st in state_groups_ids.values(): - for key, e_id in st.items(): - if key in new_state: - conflicted_state = True - break - new_state[key] = e_id - if conflicted_state: - break - - if conflicted_state: - logger.info("Resolving conflicted state for %r", room_id) - with Measure(self.clock, "state._resolve_events"): - # resolve_events_with_store returns a StateMap, but we can - # treat it as a MutableStateMap as it is above. It isn't - # actually mutated anymore (and is frozen in - # _make_state_cache_entry below). - new_state = cast( - MutableStateMap, - await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, - ), - ) + with Measure(self.clock, "state._resolve_events"): + new_state = await resolve_events_with_store( + self.clock, + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ) # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -579,8 +548,7 @@ class StateResolutionHandler: with Measure(self.clock, "state.create_group_ids"): cache = _make_state_cache_entry(new_state, state_groups_ids) - if self._state_cache is not None: - self._state_cache[group_names] = cache + self._state_cache[group_names] = cache return cache diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index ccb3384db9..0cb12f4c61 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,14 +160,20 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c5a36990e4..ef81d73573 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index e4fd979a33..2d151b9134 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -394,7 +394,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -443,7 +443,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index c04374e43d..fdf394c612 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3f..22e1ed15d0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a80f419e3..18def01f50 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -156,15 +156,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream @@ -1108,6 +1108,10 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ + + def str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) else None + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", @@ -1118,8 +1122,8 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), + "display_name": str_or_none(event.content.get("displayname")), + "avatar_url": str_or_none(event.content.get("avatar_url")), } for event in events ], diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index de9e8d1dc6..f95679ebc4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="events", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="backfill", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_backfill_stream_seq", positive=False, + writers=hs.config.worker.writers.events, ) else: # We shouldn't be running in worker mode with SQLite, but its useful diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ccfbb2135e..7218191965 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e20a16f907..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore): Raises: NotFoundError if the rule does not exist. """ - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 5867d52b62..c10a16ffa3 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -577,7 +577,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 675e81fe34..48ce7ecd16 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_expiration_ts_for_user", ) + async def is_account_expired(self, user_id: str, current_ts: int) -> bool: + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Whether the user account has expired + """ + expiration_ts = await self.get_expiration_ts_for_user(user_id) + return expiration_ts is not None and current_ts >= expiration_ts + async def set_account_validity_for_user( self, user_id: str, @@ -379,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -387,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index bd6f9553c6..3c7630857f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, @@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: str = "b", + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + event_reports: json list of event reports + count: total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn(txn): + filters = [] + args = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.reason, + er.content, + events.sender, + room_aliases.room_alias, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN room_aliases + ON room_aliases.room_id = er.room_id + JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + event_reports = self.db_pool.cursor_to_dict(txn) + + if count > 0: + for row in event_reports: + try: + row["content"] = db_to_json(row["content"]) + row["event_json"] = db_to_json(row["event_json"]) + except Exception: + continue + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4fa8767b01..86ffe2479e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set @@ -37,7 +36,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # for rooms the server is participating in. if self._current_state_events_membership_up_to_date: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ else: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN room_memberships AS m USING (room_id, event_id) INNER JOIN events AS e USING (room_id, event_id) @@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + return frozenset( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + for room_id, instance, stream_id in txn + ) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..ccf287971c 100644 --- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- room_id and topoligical_ordering are denormalised from the events table in order to +-- room_id and topological_ordering are denormalised from the events table in order to -- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres index 97c1e6a0c5..c31f9af82a 100644 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', ( CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; +-- If the server has never backfilled a room then doing `-MIN(...)` will give +-- a negative result, hence why we do `GREATEST(...)` SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events + SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events )); diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7816a8606..5beb302be3 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore): * topic * avatar * canonical_alias + * guest_access A is_federatable key can also be included with a boolean value. @@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore): "topic", "avatar", "canonical_alias", + "guest_access", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 2f7c95fc74..f9575b1f1f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore): return # They are there, delete them. - self.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "erased_users", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index d89f6ed128..603cd7d825 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, StateMap +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -190,6 +190,7 @@ class EventsPersistenceStorage: self.persist_events_store = stores.persist_events self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() @@ -198,7 +199,7 @@ class EventsPersistenceStorage: self, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> RoomStreamToken: """ Write events to the database Args: @@ -228,11 +229,11 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return self.main_store.get_current_events_token() + return RoomStreamToken(None, self.main_store.get_current_events_token()) async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[int, int]: + ) -> Tuple[PersistedEventPosition, RoomStreamToken]: """ Returns: The stream ordering of `event`, and the stream ordering of the @@ -247,7 +248,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) max_persisted_id = self.main_store.get_current_events_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) + event_stream_id = event.internal_metadata.stream_ordering + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return pos, RoomStreamToken(None, max_persisted_id) def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8c4a83a840..f152f63321 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -25,7 +25,7 @@ RoomsForUser = namedtuple( ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") + "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1de2b91587..4269eaf918 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -86,7 +87,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +102,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +114,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +122,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +141,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +150,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -184,12 +185,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +203,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,9 +225,7 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). @@ -251,30 +258,84 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + sql = self._db.engine.convert_param_style(sql) - cur.close() + cur.execute(sql, (self._stream_name,)) - return current_positions + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } + + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id,)) + + self._persisted_upto_position = min_stream_id + + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) + + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -282,59 +343,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) + return _MultiWriterCtxManager(self) - return manager() - - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -352,6 +377,21 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): @@ -482,3 +522,95 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + def _update_stream_positions_table_txn(self, txn): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + ) + + return False diff --git a/synapse/types.py b/synapse/types.py index a6fc7df22c..ec39f9e1e8 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -495,6 +495,21 @@ class StreamToken: StreamToken.START = StreamToken.from_string("s0_0") +@attr.s(slots=True, frozen=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name = attr.ib(type=str) + stream = attr.ib(type=int) + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.stream < self.stream + + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) ): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2e6e7abf1f..5cf408f21f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -23,6 +23,7 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer +from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError from synapse.crypto import keyring @@ -33,7 +34,6 @@ from synapse.crypto.keyring import ( ) from synapse.logging.context import ( LoggingContext, - PreserveLoggingContext, current_context, make_deferred_yieldable, ) @@ -68,54 +68,40 @@ class MockPerspectiveServer: class KeyringTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - self.mock_perspective_server = MockPerspectiveServer() - self.http_client = Mock() - - config = self.default_config() - config["trusted_key_servers"] = [ - { - "server_name": self.mock_perspective_server.server_name, - "verify_keys": self.mock_perspective_server.get_verify_keys(), - } - ] - - return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config - ) - - def check_context(self, _, expected): + def check_context(self, val, expected): self.assertEquals(getattr(current_context(), "request", None), expected) + return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - key1 = signedjson.key.generate_signing_key(1) + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock() + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) - kr = keyring.Keyring(self.hs) + # a signed object that we are going to try to validate + key1 = signedjson.key.generate_signing_key(1) json1 = {} signedjson.sign.sign_json(json1, "server10", key1) - persp_resp = { - "server_keys": [ - self.mock_perspective_server.get_signed_key( - "server10", signedjson.key.get_verify_key(key1) - ) - ] - } - persp_deferred = defer.Deferred() + # start off a first set of lookups. We make the mock fetcher block until this + # deferred completes. + first_lookup_deferred = Deferred() + + async def first_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_11") + self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) - async def get_perspectives(**kwargs): - self.assertEquals(current_context().request, "11") - with PreserveLoggingContext(): - await persp_deferred - return persp_resp + await make_deferred_yieldable(first_lookup_deferred) + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } - self.http_client.post_json.side_effect = get_perspectives + mock_fetcher.get_keys.side_effect = first_lookup_fetch - # start off a first set of lookups - @defer.inlineCallbacks - def first_lookup(): - with LoggingContext("11") as context_11: - context_11.request = "11" + async def first_lookup(): + with LoggingContext("context_11") as context_11: + context_11.request = "context_11" res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] @@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # the unsigned json should be rejected pretty quickly self.assertTrue(res_deferreds[1].called) try: - yield res_deferreds[1] + await res_deferreds[1] self.assertFalse("unsigned json didn't cause a failure") except SynapseError: pass @@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(res_deferreds[0].called) res_deferreds[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds[0]) + await make_deferred_yieldable(res_deferreds[0]) - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) + d0 = ensureDeferred(first_lookup()) - d0 = first_lookup() - - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - self.pump() - self.http_client.post_json.assert_called_once() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - @defer.inlineCallbacks - def second_lookup(): - with LoggingContext("12") as context_12: - context_12.request = "12" - self.http_client.post_json.reset_mock() - self.http_client.post_json.return_value = defer.Deferred() + + async def second_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_12") + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } + + mock_fetcher.get_keys.reset_mock() + mock_fetcher.get_keys.side_effect = second_lookup_fetch + second_lookup_state = [0] + + async def second_lookup(): + with LoggingContext("context_12") as context_12: + context_12.request = "context_12" res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 1 + await make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 2 - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) - - d2 = second_lookup() + d2 = ensureDeferred(second_lookup()) self.pump() - self.http_client.post_json.assert_not_called() + # the second request should be pending, but the fetcher should not yet have been + # called + self.assertEqual(second_lookup_state[0], 1) + mock_fetcher.get_keys.assert_not_called() # complete the first request - persp_deferred.callback(persp_resp) + first_lookup_deferred.callback(None) + + # and now both verifications should succeed. self.get_success(d0) self.get_success(d2) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 6aa322bf3a..969d44c787 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase): # These tests assume that it starts 1000 seconds in. self.reactor.advance(1000) + def test_device_is_created_with_invalid_name(self): + self.get_failure( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="foo", + initial_device_display_name="a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1), + ), + synapse.api.errors.SynapseError, + ) + def test_device_is_created_if_doesnt_exist(self): res = self.get_success( self.handler.check_device_registered( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 89ec5fcb31..5910772aa8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) self.assertEqual(mxid, "@test_user_2:test") + + # Test if the mxid is already taken + store = self.hs.get_datastore() + user3 = UserID.from_string("@test_user_3:test") + self.get_success( + store.register_user(user_id=user3.to_string(), password_hash=None) + ) + userinfo = {"sub": "test3", "username": "test_user_3"} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + + @override_config({"oidc_config": {"allow_existing_users": True}}) + def test_map_userinfo_to_existing_user(self): + """Existing users can log in with OpenID Connect when allow_existing_users is True.""" + store = self.hs.get_datastore() + user4 = UserID.from_string("@test_user_4:test") + self.get_success( + store.register_user(user_id=user4.to_string(), password_hash=None) + ) + userinfo = { + "sub": "test4", + "username": "test_user_4", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_4:test") diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index bc578411d6..c0ee1cfbd6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_ from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser +from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) self.replicate() + + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, j2.internal_metadata.stream_ordering)}, + {(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # the membership change is only any use to us if the room is in the # joined_rooms list. if membership_changes: - self.assertEqual( - joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)} + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering ) + self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) event_id = 0 diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index faa7f381a9..92c9058887 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. request, channel = self.make_request( diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py new file mode 100644 index 0000000000..bf79086f78 --- /dev/null +++ b/tests/rest/admin/test_event_reports.py @@ -0,0 +1,382 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import report_event + +from tests import unittest + + +class EventReportsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self.room_id2 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok) + + # Two rooms and two users. Every user sends and reports every room event + for i in range(5): + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id2, user_tok=self.other_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.admin_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id2, user_tok=self.admin_user_tok, + ) + + self.url = "/_synapse/admin/v1/event_reports" + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing list of reported events + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + def test_limit(self): + """ + Testing list of reported events with limit + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["event_reports"]) + + def test_from(self): + """ + Testing list of reported events with a defined starting point (from) + """ + + request, channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + def test_limit_and_from(self): + """ + Testing list of reported events with a defined starting point and limit + """ + + request, channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self._check_fields(channel.json_body["event_reports"]) + + def test_filter_room(self): + """ + Testing list of reported events with a filter of room + """ + + request, channel = self.make_request( + "GET", + self.url + "?room_id=%s" % self.room_id1, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["room_id"], self.room_id1) + + def test_filter_user(self): + """ + Testing list of reported events with a filter of user + """ + + request, channel = self.make_request( + "GET", + self.url + "?user_id=%s" % self.other_user, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["user_id"], self.other_user) + + def test_filter_user_and_room(self): + """ + Testing list of reported events with a filter of user and room + """ + + request, channel = self.make_request( + "GET", + self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 5) + self.assertEqual(len(channel.json_body["event_reports"]), 5) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["user_id"], self.other_user) + self.assertEqual(report["room_id"], self.room_id1) + + def test_valid_search_order(self): + """ + Testing search order. Order by timestamps. + """ + + # fetch the most recent first, largest timestamp + request, channel = self.make_request( + "GET", self.url + "?dir=b", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + report = 1 + while report < len(channel.json_body["event_reports"]): + self.assertGreaterEqual( + channel.json_body["event_reports"][report - 1]["received_ts"], + channel.json_body["event_reports"][report]["received_ts"], + ) + report += 1 + + # fetch the oldest first, smallest timestamp + request, channel = self.make_request( + "GET", self.url + "?dir=f", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + report = 1 + while report < len(channel.json_body["event_reports"]): + self.assertLessEqual( + channel.json_body["event_reports"][report - 1]["received_ts"], + channel.json_body["event_reports"][report]["received_ts"], + ) + report += 1 + + def test_invalid_search_order(self): + """ + Testing that a invalid search order returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual("Unknown direction: bar", channel.json_body["error"]) + + def test_limit_is_negative(self): + """ + Testing that a negative list parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + # `next_token` does not appear + # Number of results is the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + request, channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + request, channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + for c in content: + self.assertIn("id", c) + self.assertIn("received_ts", c) + self.assertIn("room_id", c) + self.assertIn("event_id", c) + self.assertIn("user_id", c) + self.assertIn("reason", c) + self.assertIn("content", c) + self.assertIn("sender", c) + self.assertIn("room_alias", c) + self.assertIn("event_json", c) + self.assertIn("score", c["content"]) + self.assertIn("reason", c["content"]) + self.assertIn("auth_events", c["event_json"]) + self.assertIn("type", c["event_json"]) + self.assertIn("room_id", c["event_json"]) + self.assertIn("sender", c["event_json"]) + self.assertIn("content", c["event_json"]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f96011fc1c..98d0623734 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._is_erased("@user:test", False) + d = self.store.mark_user_erased("@user:test") + self.assertIsNone(self.get_success(d)) + self._is_erased("@user:test", True) # Attempt to reactivate the user (without a password). request, channel = self.make_request( @@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) + self._is_erased("@user:test", False) def test_set_user_as_admin(self): """ @@ -996,6 +1001,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + def _is_erased(self, user_id, expect): + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 20636fc400..d4ff55fbff 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, ) return self.get_success(self.db_pool.runWithConnection(_create)) @@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", (instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) @@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, stream_id, stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) @@ -111,7 +129,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await id_gen.get_next() as stream_id: + async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) @@ -139,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ctx3 = self.get_success(id_gen.get_next()) ctx4 = self.get_success(id_gen.get_next()) - s1 = ctx1.__enter__() - s2 = ctx2.__enter__() - s3 = ctx3.__enter__() - s4 = ctx4.__enter__() + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + s3 = self.get_success(ctx3.__aenter__()) + s4 = self.get_success(ctx4.__aenter__()) self.assertEqual(s1, 8) self.assertEqual(s2, 9) @@ -152,22 +170,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx2.__exit__(None, None, None) + self.get_success(ctx2.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1.__exit__(None, None, None) + self.get_success(ctx1.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx4.__exit__(None, None, None) + self.get_success(ctx4.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx3.__exit__(None, None, None) + self.get_success(ctx3.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) @@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_rows("first", 3) self._insert_rows("second", 4) - first_id_gen = self._create_id_generator("first") - second_id_gen = self._create_id_generator("second") + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) @@ -190,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await first_id_gen.get_next() as stream_id: + async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual( @@ -208,7 +226,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # stream ID async def _get_next_async(): - with await second_id_gen.get_next() as stream_id: + async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) self.assertEqual( @@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -300,14 +318,18 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 3) - with self.get_success(id_gen.get_next()) as stream_id: - self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_persisted_upto_position(), 6) @@ -315,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + def test_restart_during_out_of_order_persistence(self): + """Test that restarting a process while another process is writing out + of order updates are handled correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # Persist two rows at once + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) + + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") + + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + + # Now if we persist the first row then both instances should jump ahead + # correctly. + self.get_success(ctx1.__aexit__(None, None, None)) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + + def test_writer_config_change(self): + """Test that changing the writer config correctly works. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + # Initial config has two writers + id_gen = self._create_id_generator("first", writers=["first", "second"]) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # New config removes one of the configs. Note that if the writer is + # removed from config we assume that it has been shut down and has + # finished persisting, hence why the persisted upto position is 5. + id_gen_2 = self._create_id_generator("second", writers=["second"]) + self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + + # This config points to a single, previously unused writer. + id_gen_3 = self._create_id_generator("third", writers=["third"]) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + + # Check that we get a sane next stream ID with this new config. + + async def _get_next_async(): + async with id_gen_3.get_next() as stream_id: + self.assertEqual(stream_id, 6) + + self.get_success(_get_next_async()) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. @@ -341,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, positive=False, ) @@ -364,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): txn.execute( "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, -stream_id, -stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) @@ -373,16 +480,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ id_gen = self._create_id_generator() - with self.get_success(id_gen.get_next()) as stream_id: - self._insert_row("master", stream_id) + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": -1}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - with self.get_success(id_gen.get_next_mult(3)) as stream_ids: - for stream_id in stream_ids: - self._insert_row("master", stream_id) + async def _get_next_async2(): + async with id_gen.get_next_mult(3) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen.get_positions(), {"master": -4}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) @@ -399,21 +512,27 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests that having multiple instances that get advanced over federation works corretly. """ - id_gen_1 = self._create_id_generator("first") - id_gen_2 = self._create_id_generator("second") + id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) + id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) - with self.get_success(id_gen_1.get_next()) as stream_id: - self._insert_row("first", stream_id) - id_gen_2.advance("first", stream_id) + async def _get_next_async(): + async with id_gen_1.get_next() as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - with self.get_success(id_gen_2.get_next()) as stream_id: - self._insert_row("second", stream_id) - id_gen_1.advance("second", stream_id) + async def _get_next_async2(): + async with id_gen_2.get_next() as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) diff --git a/tox.ini b/tox.ini index ddcab0198f..4d132eff4c 100644 --- a/tox.ini +++ b/tox.ini @@ -2,13 +2,12 @@ envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort [base] +extras = test deps = - mock python-subunit junitxml coverage coverage-enable-subprocess - parameterized # cyptography 2.2 requires setuptools >= 18.5 # @@ -36,7 +35,7 @@ setenv = [testenv] deps = {[base]deps} -extras = all +extras = all, test whitelist_externals = sh @@ -84,7 +83,6 @@ deps = # Old automat version for Twisted Automat == 0.3.0 - mock lxml coverage coverage-enable-subprocess @@ -97,7 +95,7 @@ commands = /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install' # Install Synapse itself. This won't update any libraries. - pip install -e . + pip install -e ".[test]" {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} |