diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db
index f20567ba73..361369a581 100644
--- a/.buildkite/test_db.db
+++ b/.buildkite/test_db.db
Binary files differ
diff --git a/CHANGES.md b/CHANGES.md
index 84976ab2bd..5de819ea1e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,30 @@
+Synapse 1.20.1 (2020-09-24)
+===========================
+
+Bugfixes
+--------
+
+- Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386))
+- Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394))
+
+
+Synapse 1.20.0 (2020-09-22)
+===========================
+
+No significant changes since v1.20.0rc5.
+
+Removal warning
+---------------
+
+Historically, the [Synapse Admin
+API](https://github.com/matrix-org/synapse/tree/master/docs) has been
+accessible under the `/_matrix/client/api/v1/admin`,
+`/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and
+`/_synapse/admin` prefixes. In a future release, we will be dropping support
+for accessing Synapse's Admin API using the `/_matrix/client/*` prefixes. This
+makes it easier for homeserver admins to lock down external access to the Admin
+API endpoints.
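+
+As an illustration (a hypothetical snippet, not part of this release; the
+exact rule depends on your reverse proxy and deployment), an nginx setup
+could already restrict the dedicated prefix to internal access only:
+
+```
+# Deny external access to the Admin API under the dedicated prefix.
+location /_synapse/admin {
+    return 403;
+}
+```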
+
Synapse 1.20.0rc5 (2020-09-18)
==============================
diff --git a/changelog.d/8217.feature b/changelog.d/8217.feature
new file mode 100644
index 0000000000..899cbf14ef
--- /dev/null
+++ b/changelog.d/8217.feature
@@ -0,0 +1 @@
+Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of the `event_reports` table. Contributed by @dklimpel.
\ No newline at end of file
diff --git a/changelog.d/8330.misc b/changelog.d/8330.misc
index c51370f215..fbfdd52473 100644
--- a/changelog.d/8330.misc
+++ b/changelog.d/8330.misc
@@ -1 +1 @@
-Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this.
\ No newline at end of file
+Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this.
diff --git a/changelog.d/8345.feature b/changelog.d/8345.feature
new file mode 100644
index 0000000000..4ee5b6a56e
--- /dev/null
+++ b/changelog.d/8345.feature
@@ -0,0 +1 @@
+Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang.
diff --git a/changelog.d/8353.bugfix b/changelog.d/8353.bugfix
new file mode 100644
index 0000000000..45fc0adb8d
--- /dev/null
+++ b/changelog.d/8353.bugfix
@@ -0,0 +1 @@
+Don't send push notifications to expired user accounts.
diff --git a/changelog.d/8362.bugfix b/changelog.d/8362.bugfix
new file mode 100644
index 0000000000..4e50067c87
--- /dev/null
+++ b/changelog.d/8362.bugfix
@@ -0,0 +1 @@
+Fix a regression in v1.19.0 with reactivating users through the admin API.
diff --git a/changelog.d/8364.bugfix b/changelog.d/8364.bugfix
new file mode 100644
index 0000000000..7b82cbc388
--- /dev/null
+++ b/changelog.d/8364.bugfix
@@ -0,0 +1,2 @@
+Fix a bug where the length of the device name was not limited during
+device registration.
diff --git a/changelog.d/8370.misc b/changelog.d/8370.misc
new file mode 100644
index 0000000000..1aaac1e0bf
--- /dev/null
+++ b/changelog.d/8370.misc
@@ -0,0 +1 @@
+Factor out a `_send_dummy_event_for_room` method.
diff --git a/changelog.d/8371.misc b/changelog.d/8371.misc
new file mode 100644
index 0000000000..6a54a9496a
--- /dev/null
+++ b/changelog.d/8371.misc
@@ -0,0 +1 @@
+Improve logging of state resolution.
diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc
new file mode 100644
index 0000000000..a56e36de4b
--- /dev/null
+++ b/changelog.d/8372.misc
@@ -0,0 +1 @@
+Add type annotations to `SimpleHttpClient`.
diff --git a/changelog.d/8373.bugfix b/changelog.d/8373.bugfix
new file mode 100644
index 0000000000..e9d66a2088
--- /dev/null
+++ b/changelog.d/8373.bugfix
@@ -0,0 +1 @@
+Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2.
\ No newline at end of file
diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix
new file mode 100644
index 0000000000..155bc3404f
--- /dev/null
+++ b/changelog.d/8374.bugfix
@@ -0,0 +1 @@
+Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers.
diff --git a/changelog.d/8375.doc b/changelog.d/8375.doc
new file mode 100644
index 0000000000..d291fb92fa
--- /dev/null
+++ b/changelog.d/8375.doc
@@ -0,0 +1 @@
+Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau).
diff --git a/changelog.d/8377.misc b/changelog.d/8377.misc
new file mode 100644
index 0000000000..fbfdd52473
--- /dev/null
+++ b/changelog.d/8377.misc
@@ -0,0 +1 @@
+Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this.
diff --git a/changelog.d/8383.misc b/changelog.d/8383.misc
new file mode 100644
index 0000000000..cb8318bf57
--- /dev/null
+++ b/changelog.d/8383.misc
@@ -0,0 +1 @@
+Refactor ID generators to use `async with` syntax.
diff --git a/changelog.d/8385.bugfix b/changelog.d/8385.bugfix
new file mode 100644
index 0000000000..c42502a8e0
--- /dev/null
+++ b/changelog.d/8385.bugfix
@@ -0,0 +1 @@
+Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite.
diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix
new file mode 100644
index 0000000000..24983a1e95
--- /dev/null
+++ b/changelog.d/8386.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail.
diff --git a/changelog.d/8387.feature b/changelog.d/8387.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8387.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8388.misc b/changelog.d/8388.misc
new file mode 100644
index 0000000000..aaaef88b66
--- /dev/null
+++ b/changelog.d/8388.misc
@@ -0,0 +1 @@
+Add `EventStreamPosition` type.
diff --git a/changelog.d/8396.feature b/changelog.d/8396.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8396.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8398.bugfix b/changelog.d/8398.bugfix
new file mode 100644
index 0000000000..e432aeebf1
--- /dev/null
+++ b/changelog.d/8398.bugfix
@@ -0,0 +1 @@
+Fix "Re-starting finished log context" warning when receiving an event we already had over federation.
diff --git a/changelog.d/8405.feature b/changelog.d/8405.feature
new file mode 100644
index 0000000000..f3c4a74bc7
--- /dev/null
+++ b/changelog.d/8405.feature
@@ -0,0 +1 @@
+Consolidate the SSO error template across all configurations.
diff --git a/debian/changelog b/debian/changelog
index dbf01d6b1e..264ef9ce7c 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,8 +1,18 @@
-matrix-synapse-py3 (1.20.0ubuntu1) UNRELEASED; urgency=medium
+matrix-synapse-py3 (1.20.1) stable; urgency=medium
+
+ * New synapse release 1.20.1.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 24 Sep 2020 16:25:22 +0100
+
+matrix-synapse-py3 (1.20.0) stable; urgency=medium
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.20.0.
+
+ [ Dexter Chua ]
* Use Type=notify in systemd service
- -- Dexter Chua <dec41@srcf.net> Wed, 26 Aug 2020 12:41:36 +0000
+ -- Synapse Packaging team <packages@matrix.org> Tue, 22 Sep 2020 15:19:32 +0100
matrix-synapse-py3 (1.19.3) stable; urgency=medium
diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst
new file mode 100644
index 0000000000..461be01230
--- /dev/null
+++ b/docs/admin_api/event_reports.rst
@@ -0,0 +1,129 @@
+Show reported events
+====================
+
+This API returns information about reported events.
+
+The API is::
+
+ GET /_synapse/admin/v1/event_reports?from=0&limit=10
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
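+For example (a sketch; ``<access_token>`` and ``<server>`` are placeholders)::
+
+    curl --header "Authorization: Bearer <access_token>" \
+        "https://<server>/_synapse/admin/v1/event_reports?from=0&limit=10"
+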
+It returns a JSON body like the following:
+
+.. code:: jsonc
+
+ {
+ "event_reports": [
+ {
+ "content": {
+ "reason": "foo",
+ "score": -100
+ },
+ "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY",
+ "event_json": {
+ "auth_events": [
+ "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M",
+ "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws"
+ ],
+ "content": {
+ "body": "matrix.org: This Week in Matrix",
+ "format": "org.matrix.custom.html",
+ "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>",
+ "msgtype": "m.notice"
+ },
+ "depth": 546,
+ "hashes": {
+ "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw"
+ },
+ "origin": "matrix.org",
+ "origin_server_ts": 1592291711430,
+ "prev_events": [
+ "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M"
+ ],
+ "prev_state": [],
+ "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org",
+ "sender": "@foobar:matrix.org",
+ "signatures": {
+ "matrix.org": {
+ "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg"
+ }
+ },
+ "type": "m.room.message",
+ "unsigned": {
+ "age_ts": 1592291711430,
+ }
+ },
+ "id": 2,
+ "reason": "foo",
+ "received_ts": 1570897107409,
+ "room_alias": "#alias1:matrix.org",
+ "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org",
+ "sender": "@foobar:matrix.org",
+ "user_id": "@foo:matrix.org"
+ },
+ {
+ "content": {
+ "reason": "bar",
+ "score": -100
+ },
+ "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I",
+ "event_json": {
+ // hidden items
+ // see above
+ },
+ "id": 3,
+ "reason": "bar",
+ "received_ts": 1598889612059,
+ "room_alias": "#alias2:matrix.org",
+ "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org",
+ "sender": "@foobar:matrix.org",
+ "user_id": "@bar:matrix.org"
+ }
+ ],
+ "next_token": 2,
+ "total": 4
+ }
+
+To paginate, check for ``next_token``; if it is present, call the endpoint
+again with ``from`` set to the value of ``next_token``. This will return a
+new page.
+
+If the endpoint does not return a ``next_token`` then there are no more
+reports to paginate through.
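+
+For instance, given the example response above (``"next_token": 2``), the
+next page would be fetched with::
+
+    GET /_synapse/admin/v1/event_reports?from=2&limit=10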
+
+**URL parameters:**
+
+- ``limit``: integer - Optional. The maximum number of items to return in
+  this call, used for pagination. Defaults to ``100``.
+- ``from``: integer - Optional. Offset into the returned results, used for
+  pagination. This should be treated as an opaque value and not explicitly set
+  to anything other than the return value of ``next_token`` from a previous call.
+  Defaults to ``0``.
+- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the
+ oldest first (``f``). Defaults to ``b``.
+- ``user_id``: string - Optional. Filters to only return users with user IDs that
+  contain this value. This is the user who reported the event and wrote the reason.
+- ``room_id``: string - Optional. Filters to only return rooms with room IDs that
+  contain this value.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- ``id``: integer - ID of event report.
+- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent.
+- ``room_id``: string - The ID of the room in which the event being reported is located.
+- ``event_id``: string - The ID of the reported event.
+- ``user_id``: string - This is the user who reported the event and wrote the reason.
+- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank.
+- ``content``: object - Content of reported event.
+
+ - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank.
+ - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive".
+
+- ``sender``: string - This is the ID of the user who sent the original message/event that was reported.
+- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set.
+- ``event_json``: object - Details of the original event that was reported.
+- ``next_token``: integer - Indication for pagination. See above.
+- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``).
+
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index edd109fa7b..46d8f35771 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -121,6 +121,14 @@ example.com:8448 {
**NOTE**: ensure the `nocanon` options are included.
+**NOTE 2**: It appears that Synapse is currently incompatible with the ModSecurity module for Apache (`mod_security2`). If you need it enabled for other services on your web server, you can disable it for Synapse's two VirtualHosts by including the following lines before each of the two `</VirtualHost>` above:
+
+```
+<IfModule security2_module>
+ SecRuleEngine off
+</IfModule>
+```
+
### HAProxy
```
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index fb04ff283d..845f537795 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1689,6 +1689,11 @@ oidc_config:
#
#skip_verification: true
+ # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+ # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+ #
+ #allow_existing_users: true
+
# An external module can be provided here as a custom solution to mapping
 # attributes returned from an OIDC provider onto a Matrix user.
#
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index a34bdf1830..684a518b8e 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
+ "users": ["shadow_banned"],
}
@@ -627,6 +628,7 @@ class Porter(object):
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
+ await self._setup_events_stream_seqs()
self.progress.done()
except Exception as e:
@@ -803,6 +805,29 @@ class Porter(object):
return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
+ def _setup_events_stream_seqs(self):
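+ # Restart the postgres sequences for the events stream so that newly
+ # persisted events are allocated stream orderings after the ported data.
+ # (Backfilled events use negative stream orderings, hence the -MIN below.)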
+ def r(txn):
+ txn.execute("SELECT MAX(stream_ordering) FROM events")
+ curr_id = txn.fetchone()[0]
+ if curr_id:
+ next_id = curr_id + 1
+ txn.execute(
+ "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,)
+ )
+
+ txn.execute("SELECT -MIN(stream_ordering) FROM events")
+ curr_id = txn.fetchone()[0]
+ if curr_id:
+ next_id = curr_id + 1
+ txn.execute(
+ "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
+ (next_id,),
+ )
+
+ return self.postgres_store.db_pool.runInteraction(
+ "_setup_events_stream_seqs", r
+ )
+
##############################################
# The following is simply UI stuff
diff --git a/setup.py b/setup.py
index 54ddec8f9f..926b1bc86f 100755
--- a/setup.py
+++ b/setup.py
@@ -94,6 +94,22 @@ ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"]
# Make `pip install matrix-synapse[all]` install all the optional dependencies.
CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
+# Developer dependencies should not get included in "all".
+#
+# We pin black so that our tests don't start failing on new releases.
+CONDITIONAL_REQUIREMENTS["lint"] = [
+ "isort==5.0.3",
+ "black==19.10b0",
+ "flake8-comprehensions",
+ "flake8",
+]
+
+# Dependencies which are exclusively required by unit test code. This is
+# NOT a list of all modules that are necessary to run the unit tests.
+# Tests assume that all optional dependencies are installed.
+#
+# parameterized_class decorator was introduced in parameterized 0.7.0
+CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"]
setup(
name="matrix-synapse",
diff --git a/synapse/__init__.py b/synapse/__init__.py
index a95753dcc7..e40b582bd5 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.20.0rc5"
+__version__ = "1.20.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 75388643ee..1071a0576e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -218,11 +218,7 @@ class Auth:
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
- if (
- expiration_ts is not None
- and self.clock.time_msec() >= expiration_ts
- ):
+ if await self.store.is_account_expired(user_id, self.clock.time_msec()):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 9523c3a5d9..2c1dae5984 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- info = await self.get_json(uri, {})
+ info = await self.get_json(uri)
if not _is_valid_3pe_metadata(info):
logger.warning(
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index bb9bf8598d..05a66841c3 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -194,7 +194,10 @@ class Config:
return file_stream.read()
def read_templates(
- self, filenames: List[str], custom_template_directory: Optional[str] = None,
+ self,
+ filenames: List[str],
+ custom_template_directory: Optional[str] = None,
+ autoescape: bool = False,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@@ -210,6 +213,9 @@ class Config:
custom_template_directory: A directory to try to look for the templates
before using the default Synapse template directory instead.
+ autoescape: Whether to autoescape variables before inserting them into the
+ template.
+
Raises:
ConfigError: if the file's path is incorrect or otherwise cannot be read.
@@ -233,7 +239,7 @@ class Config:
search_directories.insert(0, custom_template_directory)
loader = jinja2.FileSystemLoader(search_directories)
- env = jinja2.Environment(loader=loader, autoescape=True)
+ env = jinja2.Environment(loader=loader, autoescape=autoescape)
# Update the environment with our custom filters
env.filters.update(
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index e0939bce84..70fc8a2f62 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config):
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+ self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
@@ -158,6 +159,11 @@ class OIDCConfig(Config):
#
#skip_verification: true
+ # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
+ # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
+ #
+ #allow_existing_users: true
+
# An external module can be provided here as a custom solution to mapping
 # attributes returned from an OIDC provider onto a Matrix user.
#
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 2b03f5ac76..79668a402e 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -45,7 +45,11 @@ _TLS_VERSION_MAP = {
class ServerContextFactory(ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
- connections."""
+ connections.
+
+ TODO: replace this with an implementation of IOpenSSLServerConnectionCreator,
+ per https://github.com/matrix-org/synapse/issues/1691
+ """
def __init__(self, config):
# TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 42e4087a92..c04ad77cf9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -42,7 +42,6 @@ from synapse.api.errors import (
)
from synapse.logging.context import (
PreserveLoggingContext,
- current_context,
make_deferred_yieldable,
preserve_fn,
run_in_background,
@@ -233,8 +232,6 @@ class Keyring:
"""
try:
- ctx = current_context()
-
# map from server name to a set of outstanding request ids
server_to_request_ids = {}
@@ -265,12 +262,8 @@ class Keyring:
# if there are no more requests for this server, we can drop the lock.
if not server_requests:
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name)
-
- # ... but not immediately, as that can cause stack explosions if
- # we get a long queue of lookups.
- self.clock.call_later(0, drop_server_lock, server_name)
+ logger.debug("Releasing key lookup lock on %s", server_name)
+ drop_server_lock(server_name)
return res
@@ -335,20 +328,32 @@ class Keyring:
)
# look for any requests which weren't satisfied
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- verify_request.key_ready.errback(
- SynapseError(
- 401,
- "No key for %s with ids in %s (min_validity %i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- ),
- Codes.UNAUTHORIZED,
- )
+ while remaining_requests:
+ verify_request = remaining_requests.pop()
+ rq_str = (
+ "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
)
+ )
+
+ # If we run the errback immediately, it may cancel our
+ # loggingcontext while we are still in it, so instead we
+ # schedule it for the next time round the reactor.
+ #
+ # (this also ensures that we don't get a stack overflow if we
+ # have a massive queue of lookups waiting for this server).
+ self.clock.call_later(
+ 0,
+ verify_request.key_ready.errback,
+ SynapseError(
+ 401,
+ "Failed to find any key to satisfy %s" % (rq_str,),
+ Codes.UNAUTHORIZED,
+ ),
+ )
except Exception as err:
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
@@ -410,10 +415,23 @@ class Keyring:
# key was not valid at this point
continue
- with PreserveLoggingContext():
- verify_request.key_ready.callback(
- (server_name, key_id, fetch_key_result.verify_key)
- )
+ # we have a valid key for this request. If we run the callback
+ # immediately, it may cancel our loggingcontext while we are still in
+ # it, so instead we schedule it for the next time round the reactor.
+ #
+ # (this also ensures that we don't get a stack overflow if we had
+ # a massive queue of lookups waiting for this server).
+ logger.debug(
+ "Found key %s:%s for %s",
+ server_name,
+ key_id,
+ verify_request.request_name,
+ )
+ self.clock.call_later(
+ 0,
+ verify_request.key_ready.callback,
+ (server_name, key_id, fetch_key_result.verify_key),
+ )
completed.append(verify_request)
break
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 55a9787439..4149520d6c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.api.errors import (
+ Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
@@ -265,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
+ def _check_device_name_length(self, name: str):
+ """
+ Checks whether a device name is longer than the maximum allowed length.
+
+ Args:
+ name: The name of the device.
+
+ Raises:
+ SynapseError: if the device name is too long.
+ """
+ if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
+ raise SynapseError(
+ 400,
+ "Device display name is too long (max %i)"
+ % (MAX_DEVICE_DISPLAY_NAME_LEN,),
+ errcode=Codes.TOO_LARGE,
+ )
+
async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
@@ -282,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler):
Returns:
str: device id (generated if none was supplied)
"""
+
+ self._check_device_name_length(initial_device_display_name)
+
if device_id is not None:
new_device = await self.store.store_device(
user_id=user_id,
@@ -397,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler):
# Reject a new displayname which is too long.
new_display_name = content.get("display_name")
- if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
- raise SynapseError(
- 400,
- "Device display name is too long (max %i)"
- % (MAX_DEVICE_DISPLAY_NAME_LEN,),
- )
+
+ self._check_device_name_length(new_display_name)
try:
await self.store.update_device(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ea9264e751..9f773aefa7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
JsonDict,
MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
StateMap,
UserID,
get_domain_from_id,
@@ -2956,7 +2958,7 @@ class FederationHandler(BaseHandler):
)
return result["max_stream_id"]
else:
- max_stream_id = await self.storage.persistence.persist_events(
+ max_stream_token = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
@@ -2967,12 +2969,12 @@ class FederationHandler(BaseHandler):
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
- await self._notify_persisted_event(event, max_stream_id)
+ await self._notify_persisted_event(event, max_stream_token)
- return max_stream_id
+ return max_stream_token.stream
async def _notify_persisted_event(
- self, event: EventBase, max_stream_id: int
+ self, event: EventBase, max_stream_token: RoomStreamToken
) -> None:
"""Checks to see if notifier/pushers should be notified about the
event or not.
@@ -2998,9 +3000,11 @@ class FederationHandler(BaseHandler):
elif event.internal_metadata.is_outlier():
return
- event_stream_id = event.internal_metadata.stream_ordering
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
async def _clean_room_for_join(self, room_id: str) -> None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a8fe5cf4e2..ee271e85e5 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1138,7 +1138,7 @@ class EventCreationHandler:
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+ event_pos, max_stream_token = await self.storage.persistence.persist_event(
event, context=context
)
@@ -1149,7 +1149,7 @@ class EventCreationHandler:
def _notify():
try:
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@@ -1161,7 +1161,7 @@ class EventCreationHandler:
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
- return event_stream_id
+ return event_pos.stream
async def _bump_active_time(self, user: UserID) -> None:
try:
@@ -1182,54 +1182,7 @@ class EventCreationHandler:
)
for room_id in room_ids:
- # For each room we need to find a joined member we can use to send
- # the dummy event with.
-
- latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-
- members = await self.state.get_current_users_in_room(
- room_id, latest_event_ids=latest_event_ids
- )
- dummy_event_sent = False
- for user_id in members:
- if not self.hs.is_mine_id(user_id):
- continue
- requester = create_requester(user_id)
- try:
- event, context = await self.create_event(
- requester,
- {
- "type": "org.matrix.dummy_event",
- "content": {},
- "room_id": room_id,
- "sender": user_id,
- },
- prev_event_ids=latest_event_ids,
- )
-
- event.internal_metadata.proactively_send = False
-
- # Since this is a dummy-event it is OK if it is sent by a
- # shadow-banned user.
- await self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=False,
- ignore_shadow_ban=True,
- )
- dummy_event_sent = True
- break
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
- except AuthError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of power. Will try another user" % (room_id, user_id)
- )
+ dummy_event_sent = await self._send_dummy_event_for_room(room_id)
if not dummy_event_sent:
# Did not find a valid user in the room, so remove from future attempts
@@ -1242,6 +1195,59 @@ class EventCreationHandler:
now = self.clock.time_msec()
self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
+ async def _send_dummy_event_for_room(self, room_id: str) -> bool:
+ """Attempt to send a dummy event for the given room.
+
+ Args:
+ room_id: room to try to send an event from
+
+ Returns:
+ True if a dummy event was successfully sent. False if no user was able
+ to send an event.
+ """
+
+ # For each room we need to find a joined member we can use to send
+ # the dummy event with.
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
+ members = await self.state.get_current_users_in_room(
+ room_id, latest_event_ids=latest_event_ids
+ )
+ for user_id in members:
+ if not self.hs.is_mine_id(user_id):
+ continue
+ requester = create_requester(user_id)
+ try:
+ event, context = await self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+
+ event.internal_metadata.proactively_send = False
+
+ # Since this is a dummy-event it is OK if it is sent by a
+ # shadow-banned user.
+ await self.send_nonmember_event(
+ requester, event, context, ratelimit=False, ignore_shadow_ban=True,
+ )
+ return True
+ except ConsentNotGivenError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of consent. Will try another user" % (room_id, user_id)
+ )
+ except AuthError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of power. Will try another user" % (room_id, user_id)
+ )
+ return False
+
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 4230dbaf99..0e06e4408d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -114,6 +114,7 @@ class OidcHandler:
hs.config.oidc_user_mapping_provider_config
) # type: OidcMappingProvider
self._skip_verification = hs.config.oidc_skip_verification # type: bool
+ self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
@@ -849,7 +850,8 @@ class OidcHandler:
If we don't find the user that way, we should register the user,
mapping the localpart and the display name from the UserInfo.
- If a user already exists with the mxid we've mapped, raise an exception.
+ If a user already exists with the mxid we've mapped and allow_existing_users
+ is disabled, raise an exception.
Args:
userinfo: an object representing the user
@@ -905,21 +907,31 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"])
- user_id = UserID(localpart, self._hostname)
- if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
- # This mxid is taken
- raise MappingException(
- "mxid '{}' is already taken".format(user_id.to_string())
+ user_id = UserID(localpart, self._hostname).to_string()
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ if self._allow_existing_users:
+ if len(users) == 1:
+ registered_user_id = next(iter(users))
+ elif user_id in users:
+ registered_user_id = user_id
+ else:
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, list(users.keys())
+ )
+ )
+ else:
+ # This mxid is taken
+ raise MappingException("mxid '{}' is already taken".format(user_id))
+ else:
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=attributes["display_name"],
+ user_agent_ips=(user_agent, ip_address),
)
-
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
-
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id,
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9b3a4f638b..e948efef2e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -967,7 +967,7 @@ class SyncHandler:
raise NotImplementedError()
else:
joined_room_ids = await self.get_rooms_for_user_at(
- user_id, now_token.room_stream_id
+ user_id, now_token.room_key
)
sync_result_builder = SyncResultBuilder(
sync_config,
@@ -1916,7 +1916,7 @@ class SyncHandler:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
async def get_rooms_for_user_at(
- self, user_id: str, stream_ordering: int
+ self, user_id: str, room_key: RoomStreamToken
) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering.
@@ -1942,15 +1942,15 @@ class SyncHandler:
# If the membership's stream ordering is after the given stream
# ordering, we need to go and work out if the user was in the room
# before.
- for room_id, membership_stream_ordering in joined_rooms:
- if membership_stream_ordering <= stream_ordering:
+ for room_id, event_pos in joined_rooms:
+ if not event_pos.persisted_after(room_key):
joined_room_ids.add(room_id)
continue
logger.info("User joined room after current token: %s", room_id)
extrems = await self.store.get_forward_extremeties_for_room(
- room_id, stream_ordering
+ room_id, event_pos.stream
)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..4694adc400 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,6 +17,18 @@
import logging
import urllib
from io import BytesIO
+from typing import (
+ Any,
+ BinaryIO,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
import treq
from canonicaljson import encode_canonical_json
@@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import (
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
"synapse_http_client_responses", "", ["method", "code"]
)
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
+# we simplify.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value actually has to be a List, but List is invariant so we can't specify that
+# the entries can either be Lists or bytes.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
@@ -285,13 +311,26 @@ class SimpleHttpClient:
ip_blacklist=self._ip_blacklist,
)
- async def request(self, method, uri, data=None, headers=None):
+ async def request(
+ self,
+ method: str,
+ uri: str,
+ data: Optional[bytes] = None,
+ headers: Optional[Headers] = None,
+ ) -> IResponse:
"""
Args:
- method (str): HTTP method to use.
- uri (str): URI to query.
- data (bytes): Data to send in the request body, if applicable.
- headers (t.w.http_headers.Headers): Request headers.
+ method: HTTP method to use.
+ uri: URI to query.
+ data: Data to send in the request body, if applicable.
+ headers: Request headers.
+
+ Returns:
+ Response object, once the headers have been read.
+
+ Raises:
+ RequestTimedOutError: if the request times out before the headers are read
+
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
@@ -324,6 +363,8 @@ class SimpleHttpClient:
headers=headers,
**self._extra_treq_args
)
+ # we use our own timeout mechanism rather than treq's as a workaround
+ # for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
request_deferred,
60,
@@ -353,18 +394,26 @@ class SimpleHttpClient:
set_tag("error_reason", e.args[0])
raise
- async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ async def post_urlencoded_get_json(
+ self,
+ uri: str,
+ args: Mapping[str, Union[str, List[str]]] = {},
+ headers: Optional[RawHeaders] = None,
+ ) -> Any:
"""
Args:
- uri (str):
- args (dict[str, str|List[str]]): query params
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: uri to query
+ args: parameters to be url-encoded in the body
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def post_json_get_json(self, uri, post_json, headers=None):
+ async def post_json_get_json(
+ self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+ ) -> Any:
"""
Args:
- uri (str):
- post_json (object):
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: URI to query.
+ post_json: request body, to be encoded as json
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_json(self, uri, args={}, headers=None):
- """ Gets some json from the given URI.
+ async def get_json(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Gets some json from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query string
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
- async def put_json(self, uri, json_body, args={}, headers=None):
- """ Puts some json to the given URI.
+ async def put_json(
+ self,
+ uri: str,
+ json_body: Any,
+ args: QueryParams = {},
+ headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Puts some json to the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- json_body (dict): The JSON to put in the HTTP body,
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ json_body: The JSON to put in the HTTP body.
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_raw(self, uri, args={}, headers=None):
- """ Gets raw text from the given URI.
+ async def get_raw(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ ) -> bytes:
+ """Gets raw text from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get a 2xx HTTP response, with the
HTTP body as bytes.
Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
- async def get_file(self, url, output_stream, max_size=None, headers=None):
+ async def get_file(
+ self,
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
"""GETs a file from a given URL
Args:
- url (str): The URL to GET
- output_stream (file): File to write the response body to.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ url: The URL to GET
+ output_stream: File to write the response body to.
+ headers: A map from header name to a list of values for that header
Returns:
- A (int,dict,string,int) tuple of the file length, dict of the response
+ A tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
+
+ Raises:
+ RequestTimedOutException: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
+ SynapseError: if the response is not a 2xx, the remote file is too large, or
+ another exception happens during the download.
"""
actual_headers = {b"User-Agent": [self.user_agent]}
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 37546e6c21..b6b231c15d 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -42,7 +42,13 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ Collection,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StreamToken,
+ UserID,
+)
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -187,7 +193,7 @@ class Notifier:
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
- ) # type: List[Tuple[int, EventBase, Collection[UserID]]]
+ ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -246,8 +252,8 @@ class Notifier:
def on_new_room_event(
self,
event: EventBase,
- room_stream_id: int,
- max_room_stream_id: int,
+ event_pos: PersistedEventPosition,
+ max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
@@ -261,16 +267,16 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append((room_stream_id, event, extra_users))
- self._notify_pending_new_room_events(max_room_stream_id)
+ self.pending_new_room_events.append((event_pos, event, extra_users))
+ self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
- def _notify_pending_new_room_events(self, max_room_stream_id: int):
+ def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
- max_room_stream_id: The highest stream_id below which all
+ max_room_stream_token: The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
@@ -279,11 +285,9 @@ class Notifier:
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
- for room_stream_id, event, extra_users in pending:
- if room_stream_id > max_room_stream_id:
- self.pending_new_room_events.append(
- (room_stream_id, event, extra_users)
- )
+ for event_pos, event, extra_users in pending:
+ if event_pos.persisted_after(max_room_stream_token):
+ self.pending_new_room_events.append((event_pos, event, extra_users))
else:
if (
event.type == EventTypes.Member
@@ -296,33 +300,32 @@ class Notifier:
if users or rooms:
self.on_new_event(
- "room_key",
- RoomStreamToken(None, max_room_stream_id),
- users=users,
- rooms=rooms,
+ "room_key", max_room_stream_token, users=users, rooms=rooms,
)
- self._on_updated_room_token(max_room_stream_id)
+ self._on_updated_room_token(max_room_stream_token)
- def _on_updated_room_token(self, max_room_stream_id: int):
+ def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
"""Poke services that might care that the room position has been
updated.
"""
# poke any interested application service.
run_as_background_process(
- "_notify_app_services", self._notify_app_services, max_room_stream_id
+ "_notify_app_services", self._notify_app_services, max_room_stream_token
)
run_as_background_process(
- "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
+ "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
)
if self.federation_sender:
- self.federation_sender.notify_new_events(max_room_stream_id)
+ self.federation_sender.notify_new_events(max_room_stream_token.stream)
- async def _notify_app_services(self, max_room_stream_id: int):
+ async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
try:
- await self.appservice_handler.notify_interested_services(max_room_stream_id)
+ await self.appservice_handler.notify_interested_services(
+ max_room_stream_token.stream
+ )
except Exception:
logger.exception("Error notifying application services of event")
@@ -332,9 +335,9 @@ class Notifier:
except Exception:
logger.exception("Error notifying application services of event")
- async def _notify_pusher_pool(self, max_room_stream_id: int):
+ async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
- await self._pusher_pool.on_new_notifications(max_room_stream_id)
+ await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
except Exception:
logger.exception("Error pusher pool of event")
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index cc839ffce4..76150e117b 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -60,6 +60,8 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
+ self._account_validity = hs.config.account_validity
+
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
@@ -202,6 +204,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_notifications(max_stream_id)
@@ -222,6 +232,14 @@ class PusherPool:
)
for u in users_affected:
+ # Don't push if the user account has expired
+ if self._account_validity.enabled:
+ expired = await self.store.is_account_expired(
+ u, self.clock.time_msec()
+ )
+ if expired:
+ continue
+
if u in self.pushers:
for p in self.pushers[u].values():
p.on_new_receipts(min_stream_id, max_stream_id)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 67f019fd22..288631477e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -37,6 +37,9 @@ logger = logging.getLogger(__name__)
# installed when that optional dependency requirement is specified. It is passed
# to setup() as extras_require in setup.py
#
+# Note that these both represent runtime dependencies (and the versions
+# installed are checked at runtime).
+#
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
@@ -92,20 +95,12 @@ CONDITIONAL_REQUIREMENTS = {
"oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
- # Dependencies which are exclusively required by unit test code. This is
- # NOT a list of all modules that are necessary to run the unit tests.
- # Tests assume that all optional dependencies are installed.
- #
- # parameterized_class decorator was introduced in parameterized 0.7.0
- "test": ["mock>=2.0", "parameterized>=0.7.0"],
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
# hiredis is not a *strict* dependency, but it makes things much faster.
# (if it is not installed, we fall back to slow code.)
"redis": ["txredisapi>=1.4.7", "hiredis"],
- # We pin black so that our tests don't start failing on new releases.
- "lint": ["isort==5.0.3", "black==19.10b0", "flake8-comprehensions", "flake8"],
}
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
@@ -113,7 +108,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement.
# Exclude lint as it's a dev-based requirement.
- if name not in ["systemd", "lint"]:
+ if name not in ["systemd"]:
ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d25fa49e1a..d0089fe06c 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
+ stream_name="caches",
instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
) # type: Optional[MultiWriterIdGenerator]
else:
self._cache_id_gen = None
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e82b9e386f..55af3d41ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
-from synapse.types import UserID
+from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -151,8 +151,14 @@ class ReplicationDataHandler:
extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
extra_users = (UserID.from_string(event.state_key),)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ max_token = RoomStreamToken(
+ None, self.store.get_room_max_stream_ordering()
+ )
+ event_pos = PersistedEventPosition(instance_name, token)
+ self.notifier.on_new_room_event(
+ event, event_pos, max_token, extra_users
+ )
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index af8459719a..944bc9c9ca 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -12,7 +12,7 @@
<p>
There was an error during authentication:
</p>
- <div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
+ <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 4a75c06480..5c5f00b213 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
DeviceRestServlet,
DevicesRestServlet,
)
+from synapse.rest.admin.event_reports import EventReportsRestServlet
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -216,6 +217,7 @@ def register_servlets(hs, http_server):
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
+ EventReportsRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
new file mode 100644
index 0000000000..5b8d0594cd
--- /dev/null
+++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+logger = logging.getLogger(__name__)
+
+
+class EventReportsRestServlet(RestServlet):
+ """
+ List all reported events that are known to the homeserver. Results are returned
+ in a dictionary containing report information. Supports pagination.
+ The requester must have administrator access in Synapse.
+
+ GET /_synapse/admin/v1/event_reports
+ returns:
+ 200 OK with a list of reports on success, otherwise an error.
+
+ Args:
+ The parameters `from` and `limit` are optional, and are used for pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `dir` can be used to define the order of results.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `room_id` can be used to filter by room id.
+ Returns:
+ A list of reported events and an integer representing the total number of
+ reported events that exist given this query.
+ """
+
+ PATTERNS = admin_patterns("/event_reports$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ direction = parse_string(request, "dir", default="b")
+ user_id = parse_string(request, "user_id")
+ room_id = parse_string(request, "room_id")
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "The start parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "The limit parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ event_reports, total = await self.store.get_event_reports_paginate(
+ start, limit, direction, user_id, room_id
+ )
+ ret = {"event_reports": event_reports, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(event_reports)
+
+ return 200, ret
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 987765e877..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url, user):
+ async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url
+ url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
- etag = headers["ETag"][0] if "ETag" in headers else None
+ etag = (
+ headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ )
else:
- html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ # we can only get here if we did an oembed request and have an oembed_result.html
+ assert oembed_result.html is not None
+ assert oembed_url is not None
+
+ html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 56d6afb863..5a5ea39e01 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -25,7 +25,6 @@ from typing import (
Sequence,
Set,
Union,
- cast,
overload,
)
@@ -42,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
-from synapse.types import Collection, MutableStateMap, StateMap
+from synapse.types import Collection, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -472,10 +471,9 @@ class StateResolutionHandler:
def __init__(self, hs):
self.clock = hs.get_clock()
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
@@ -519,57 +517,28 @@ class StateResolutionHandler:
Returns:
The resolved state
"""
- logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
-
group_names = frozenset(state_groups_ids.keys())
with (await self.resolve_linearizer.queue(group_names)):
- if self._state_cache is not None:
- cache = self._state_cache.get(group_names, None)
- if cache:
- return cache
+ cache = self._state_cache.get(group_names, None)
+ if cache:
+ return cache
logger.info(
- "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
+ "Resolving state for %s with groups %s", room_id, list(group_names),
)
state_groups_histogram.observe(len(state_groups_ids))
- # start by assuming we won't have any conflicted state, and build up the new
- # state map by iterating through the state groups. If we discover a conflict,
- # we give up and instead use `resolve_events_with_store`.
- #
- # XXX: is this actually worthwhile, or should we just let
- # resolve_events_with_store do it?
- new_state = {} # type: MutableStateMap[str]
- conflicted_state = False
- for st in state_groups_ids.values():
- for key, e_id in st.items():
- if key in new_state:
- conflicted_state = True
- break
- new_state[key] = e_id
- if conflicted_state:
- break
-
- if conflicted_state:
- logger.info("Resolving conflicted state for %r", room_id)
- with Measure(self.clock, "state._resolve_events"):
- # resolve_events_with_store returns a StateMap, but we can
- # treat it as a MutableStateMap as it is above. It isn't
- # actually mutated anymore (and is frozen in
- # _make_state_cache_entry below).
- new_state = cast(
- MutableStateMap,
- await resolve_events_with_store(
- self.clock,
- room_id,
- room_version,
- list(state_groups_ids.values()),
- event_map=event_map,
- state_res_store=state_res_store,
- ),
- )
+ with Measure(self.clock, "state._resolve_events"):
+ new_state = await resolve_events_with_store(
+ self.clock,
+ room_id,
+ room_version,
+ list(state_groups_ids.values()),
+ event_map=event_map,
+ state_res_store=state_res_store,
+ )
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -579,8 +548,7 @@ class StateResolutionHandler:
with Measure(self.clock, "state.create_group_ids"):
cache = _make_state_cache_entry(new_state, state_groups_ids)
- if self._state_cache is not None:
- self._state_cache[group_names] = cache
+ self._state_cache[group_names] = cache
return cache
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ccb3384db9..0cb12f4c61 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,14 +160,20 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
+ # We set the `writers` to an empty list here as we don't care about
+ # missing updates over restarts, as we'll not have anything in our
+ # caches to invalidate. (This reduces the number of writes to the DB
+ # that happen.)
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
- instance_name="master",
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
)
else:
self._cache_id_gen = None
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index c5a36990e4..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e4fd979a33..2d151b9134 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -394,7 +394,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -443,7 +443,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index c04374e43d..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
- with await self._cross_signing_id_gen.get_next() as stream_id:
+ async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -156,15 +156,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
@@ -1108,6 +1108,10 @@ class PersistEventsStore:
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
+
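+ # Membership event content is client-supplied, so these fields may hold
+ # non-string values; store NULL for anything that isn't a string.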
+ def str_or_none(val: Any) -> Optional[str]:
+ return val if isinstance(val, str) else None
+
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
@@ -1118,8 +1122,8 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
+ "display_name": str_or_none(event.content.get("displayname")),
+ "avatar_url": str_or_none(event.content.get("avatar_url")),
}
for event in events
],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index de9e8d1dc6..f95679ebc4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ stream_name="events",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
+ writers=hs.config.worker.writers.events,
)
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ stream_name="backfill",
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
+ writers=hs.config.worker.writers.events,
)
else:
# We shouldn't be running in worker mode with SQLite, but it's useful
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with await self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e20a16f907..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
NotFoundError if the rule does not exist.
"""
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
@@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5867d52b62..c10a16ffa3 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -577,7 +577,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- with await self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 675e81fe34..48ce7ecd16 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+ Returns whether a user account has expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
@@ -379,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore):
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
- ) -> str:
+ ) -> Optional[str]:
"""Look up a user by their external auth id
Args:
@@ -387,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore):
external_id: id on that system
Returns:
- str|None: the mxid of the user, or None if they are not known
+ the mxid of the user, or None if they are not known
"""
return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index bd6f9553c6..3c7630857f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
desc="add_event_report",
)
+ async def get_event_reports_paginate(
+ self,
+ start: int,
+ limit: int,
+ direction: str = "b",
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Retrieve a paginated list of event reports
+
+ Args:
+ start: event offset to begin the query from
+ limit: number of rows to retrieve
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`)
+ user_id: search for user_id. Ignored if user_id is None
+ room_id: search for room_id. Ignored if room_id is None
+ Returns:
+ event_reports: json list of event reports
+ count: total number of event reports matching the filter criteria
+ """
+
+ def _get_event_reports_paginate_txn(txn):
+ filters = []
+ args = []
+
+ if user_id:
+ filters.append("er.user_id LIKE ?")
+ args.extend(["%" + user_id + "%"])
+ if room_id:
+ filters.append("er.room_id LIKE ?")
+ args.extend(["%" + room_id + "%"])
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql = """
+ SELECT COUNT(*) as total_event_reports
+ FROM event_reports AS er
+ {}
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.reason,
+ er.content,
+ events.sender,
+ room_aliases.room_alias,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN room_aliases
+ ON room_aliases.room_id = er.room_id
+ JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ {where_clause}
+ ORDER BY er.received_ts {order}
+ LIMIT ?
+ OFFSET ?
+ """.format(
+ where_clause=where_clause, order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ event_reports = self.db_pool.cursor_to_dict(txn)
+
+ if count > 0:
+ for row in event_reports:
+ try:
+ row["content"] = db_to_json(row["content"])
+ row["event_json"] = db_to_json(row["event_json"])
+ except Exception:
+ continue
+
+ return event_reports, count
+
+ return await self.db_pool.runInteraction(
+ "get_event_reports_paginate", _get_event_reports_paginate_txn
+ )
+
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4fa8767b01..86ffe2479e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
@@ -37,7 +36,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# for rooms the server is participating in.
if self._current_state_events_membership_up_to_date:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
@@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
else:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+ return frozenset(
+ GetRoomsForUserWithStreamOrdering(
+ room_id, PersistedEventPosition(instance, stream_id)
+ )
+ for room_id, instance, stream_id in txn
+ )
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..ccf287971c 100644
--- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -13,7 +13,7 @@
* limitations under the License.
*/
--- room_id and topoligical_ordering are denormalised from the events table in order to
+-- room_id and topological_ordering are denormalised from the events table in order to
-- make the index work.
CREATE TABLE IF NOT EXISTS event_labels (
event_id TEXT,
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
index 97c1e6a0c5..c31f9af82a 100644
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+-- If the server has never backfilled a room then doing `-MIN(...)` will give
+-- a negative result, which is why we wrap it in `GREATEST(...)`.
SELECT setval('events_backfill_stream_seq', (
- SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+ SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
));
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+ stream_name TEXT NOT NULL,
+ instance_name TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
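+
+-- Illustrative contents (values hypothetical): one row per (stream, writer)
+-- pair, recording that writer's latest persisted stream ID, e.g.
+--   ('events', 'master', 1234)
+--   ('caches', 'master', 56)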
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7816a8606..5beb302be3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore):
* topic
* avatar
* canonical_alias
+ * guest_access
A is_federatable key can also be included with a boolean value.
@@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore):
"topic",
"avatar",
"canonical_alias",
+ "guest_access",
):
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
return
# They are there, delete them.
- self.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, "erased_users", keyvalues={"user_id": user_id}
)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d89f6ed128..603cd7d825 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -190,6 +190,7 @@ class EventsPersistenceStorage:
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -198,7 +199,7 @@ class EventsPersistenceStorage:
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> int:
+ ) -> RoomStreamToken:
"""
Write events to the database
Args:
@@ -228,11 +229,11 @@ class EventsPersistenceStorage:
defer.gatherResults(deferreds, consumeErrors=True)
)
- return self.main_store.get_current_events_token()
+ return RoomStreamToken(None, self.main_store.get_current_events_token())
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ) -> Tuple[int, int]:
+ ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
"""
Returns:
The stream ordering of `event`, and the stream ordering of the
@@ -247,7 +248,10 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
max_persisted_id = self.main_store.get_current_events_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
+ event_stream_id = event.internal_metadata.stream_ordering
+
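+ # Return both where this event landed (instance name plus stream
+ # ordering) and a token for the latest persisted event.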
+ pos = PersistedEventPosition(self._instance_name, event_stream_id)
+ return pos, RoomStreamToken(None, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+ "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..4269eaf918 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator
@@ -86,7 +87,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +102,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +114,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +122,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +141,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +150,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
+ stream_name: A name for the stream.
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ writers: A list of known writers to use to populate current positions
+ on startup. Can be empty if nothing uses `get_current_token` or
+ `get_positions` (e.g. caches stream).
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
self,
db_conn,
db: DatabasePool,
+ stream_name: str,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
+ writers: List[str],
positive: bool = True,
):
self._db = db
+ self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
+ self._writers = writers
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
@@ -216,9 +225,7 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
- self._current_positions = self._load_current_ids(
- db_conn, table, instance_column, id_column
- )
+ self._current_positions = {} # type: Dict[str, int]
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
@@ -251,30 +258,84 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+ # This goes and fills out the above state from the database.
+ self._load_current_ids(db_conn, table, instance_column, id_column)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
- ) -> Dict[str, int]:
- # If positive stream aggregate via MAX. For negative stream use MIN
- # *and* negate the result to get a positive number.
- sql = """
- SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
- GROUP BY %(instance)s
- """ % {
- "instance": instance_column,
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
-
+ ):
cur = db_conn.cursor()
- cur.execute(sql)
- # `cur` is an iterable over returned rows, which are 2-tuples.
- current_positions = dict(cur)
+ # Load the current positions of all writers for the stream.
+ if self._writers:
+ sql = """
+ SELECT instance_name, stream_id FROM stream_positions
+ WHERE stream_name = ?
+ """
+ sql = self._db.engine.convert_param_style(sql)
- cur.close()
+ cur.execute(sql, (self._stream_name,))
- return current_positions
+ self._current_positions = {
+ instance: stream_id * self._return_factor
+ for instance, stream_id in cur
+ if instance in self._writers
+ }
+
+ # We set the `_persisted_upto_position` to be the minimum of all current
+ # positions. If there are none, we use the max stream ID from the DB table.
+ min_stream_id = min(self._current_positions.values(), default=None)
+
+ if min_stream_id is None:
+ # We add a GREATEST here to ensure that the result is always
+ # positive. (A non-positive result is possible for e.g. backfill
+ # streams where the server has never backfilled anything.)
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+ self._persisted_upto_position = stream_id
+ else:
+ # If we have a min_stream_id then we pull out everything greater
+ # than it from the DB so that we can prefill
+ # `_known_persisted_positions` and get a more accurate
+ # `_persisted_upto_position`.
+ #
+ # We also check if any of the later rows are from this instance, in
+ # which case we use that for this instance's current position. This
+ # is to handle the case where we didn't finish persisting to the
+ # stream positions table before restart (or the stream position
+ # table otherwise got out of date).
+
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ sql = self._db.engine.convert_param_style(sql)
+ cur.execute(sql, (min_stream_id,))
+
+ self._persisted_upto_position = min_stream_id
+
+ with self._lock:
+ for (instance, stream_id,) in cur:
+ stream_id = self._return_factor * stream_id
+ self._add_persisted_position(stream_id)
+
+ if instance == self._instance_name:
+ self._current_positions[instance] = stream_id
+
+ cur.close()
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
@@ -282,59 +343,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
+ return _MultiWriterCtxManager(self)
- return manager()
-
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
-
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -352,6 +377,21 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ # Update the `stream_positions` table with the newly updated stream
+ # ID (unless `self._writers` is empty, in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
@@ -482,3 +522,95 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+ def _update_stream_positions_table_txn(self, txn):
+ """Update the `stream_positions` table with newly persisted position.
+ """
+
+ if not self._writers:
+ return
+
+ # We upsert the value, ensuring on conflict that we always increase the
+ # value (or decrease it if the stream grows downwards).
+ sql = """
+ INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+ VALUES (?, ?, ?)
+ ON CONFLICT (stream_name, instance_name)
+ DO UPDATE SET
+ stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+ """ % {
+ "agg": "GREATEST" if self._positive else "LEAST",
+ }
+
+ pos = self.get_current_token_for_writer(self._instance_name)
+ txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
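+
+# For instance, `StreamIdGenerator.get_next()` above returns
+# `_AsyncCtxManagerWrapper(manager())`, which is what lets callers write
+# `async with stream_id_gen.get_next() as stream_id`.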
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+ # Update the `stream_positions` table with the newly updated stream
+ # ID (unless `id_gen._writers` is empty, in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self.id_gen._writers:
+ await self.id_gen._db.runInteraction(
+ "MultiWriterIdGenerator._update_table",
+ self.id_gen._update_stream_positions_table_txn,
+ )
+
+ return False
diff --git a/synapse/types.py b/synapse/types.py
index a6fc7df22c..ec39f9e1e8 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -495,6 +495,21 @@ class StreamToken:
StreamToken.START = StreamToken.from_string("s0_0")
+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
+ instance_name = attr.ib(type=str)
+ stream = attr.ib(type=int)
+
+ def persisted_after(self, token: RoomStreamToken) -> bool:
+ return token.stream < self.stream
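+ # e.g. (illustrative) PersistedEventPosition("master", 7).persisted_after(
+ #     RoomStreamToken(None, 5)
+ # ) is True, since stream ordering 7 lies past position 5.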
+
+
class ThirdPartyInstanceID(
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2e6e7abf1f..5cf408f21f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
)
from synapse.logging.context import (
LoggingContext,
- PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
@@ -68,54 +68,40 @@ class MockPerspectiveServer:
class KeyringTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- self.mock_perspective_server = MockPerspectiveServer()
- self.http_client = Mock()
-
- config = self.default_config()
- config["trusted_key_servers"] = [
- {
- "server_name": self.mock_perspective_server.server_name,
- "verify_keys": self.mock_perspective_server.get_verify_keys(),
- }
- ]
-
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
-
- def check_context(self, _, expected):
+ def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
+ return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- key1 = signedjson.key.generate_signing_key(1)
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock()
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- kr = keyring.Keyring(self.hs)
+ # a signed object that we are going to try to validate
+ key1 = signedjson.key.generate_signing_key(1)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
- persp_resp = {
- "server_keys": [
- self.mock_perspective_server.get_signed_key(
- "server10", signedjson.key.get_verify_key(key1)
- )
- ]
- }
- persp_deferred = defer.Deferred()
+ # start off a first set of lookups. We make the mock fetcher block until this
+ # deferred completes.
+ first_lookup_deferred = Deferred()
+
+ async def first_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_11")
+ self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
- async def get_perspectives(**kwargs):
- self.assertEquals(current_context().request, "11")
- with PreserveLoggingContext():
- await persp_deferred
- return persp_resp
+ await make_deferred_yieldable(first_lookup_deferred)
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
- self.http_client.post_json.side_effect = get_perspectives
+ mock_fetcher.get_keys.side_effect = first_lookup_fetch
- # start off a first set of lookups
- @defer.inlineCallbacks
- def first_lookup():
- with LoggingContext("11") as context_11:
- context_11.request = "11"
+ async def first_lookup():
+ with LoggingContext("context_11") as context_11:
+ context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
- yield res_deferreds[1]
+ await res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
@@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds[0])
+ await make_deferred_yieldable(res_deferreds[0])
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
+ d0 = ensureDeferred(first_lookup())
- d0 = first_lookup()
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- self.pump()
- self.http_client.post_json.assert_called_once()
+ mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
- @defer.inlineCallbacks
- def second_lookup():
- with LoggingContext("12") as context_12:
- context_12.request = "12"
- self.http_client.post_json.reset_mock()
- self.http_client.post_json.return_value = defer.Deferred()
+
+ async def second_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_12")
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.reset_mock()
+ mock_fetcher.get_keys.side_effect = second_lookup_fetch
+ second_lookup_state = [0]
+
+ async def second_lookup():
+ with LoggingContext("context_12") as context_12:
+ context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 1
+ await make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 2
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
-
- d2 = second_lookup()
+ d2 = ensureDeferred(second_lookup())
self.pump()
- self.http_client.post_json.assert_not_called()
+ # the second request should be pending, but the fetcher should not yet have been
+ # called
+ self.assertEqual(second_lookup_state[0], 1)
+ mock_fetcher.get_keys.assert_not_called()
# complete the first request
- persp_deferred.callback(persp_resp)
+ first_lookup_deferred.callback(None)
+
+ # and now both verifications should succeed.
self.get_success(d0)
self.get_success(d2)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 6aa322bf3a..969d44c787 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
+ def test_device_is_created_with_invalid_name(self):
+ self.get_failure(
+ self.handler.check_device_registered(
+ user_id="@boris:foo",
+ device_id="foo",
+ initial_device_display_name="a"
+ * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ ),
+ synapse.api.errors.SynapseError,
+ )
+
def test_device_is_created_if_doesnt_exist(self):
res = self.get_success(
self.handler.check_device_registered(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 89ec5fcb31..5910772aa8 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
self.assertEqual(mxid, "@test_user_2:test")
+
+ # Test if the mxid is already taken
+ store = self.hs.get_datastore()
+ user3 = UserID.from_string("@test_user_3:test")
+ self.get_success(
+ store.register_user(user_id=user3.to_string(), password_hash=None)
+ )
+ userinfo = {"sub": "test3", "username": "test_user_3"}
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+
+ @override_config({"oidc_config": {"allow_existing_users": True}})
+ def test_map_userinfo_to_existing_user(self):
+ """Existing users can log in with OpenID Connect when allow_existing_users is True."""
+ store = self.hs.get_datastore()
+ user4 = UserID.from_string("@test_user_4:test")
+ self.get_success(
+ store.register_user(user_id=user4.to_string(), password_hash=None)
+ )
+ userinfo = {
+ "sub": "test4",
+ "username": "test_user_4",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user_4:test")
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index bc578411d6..c0ee1cfbd6 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from synapse.types import PersistedEventPosition
from tests.server import FakeTransport
@@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
self.replicate()
+
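+ # the raw stream ordering is now wrapped in a PersistedEventPosition,
+ # which also records the name of the instance that persisted the event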
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
+ )
self.check(
"get_rooms_for_user_with_stream_ordering",
(USER_ID_2,),
- {(ROOM_ID, j2.internal_metadata.stream_ordering)},
+ {(ROOM_ID, expected_pos)},
)
def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# the membership change is only of any use to us if the room is in the
# joined_rooms list.
if membership_changes:
- self.assertEqual(
- joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)}
+ expected_pos = PersistedEventPosition(
+ "master", j2.internal_metadata.stream_ordering
)
+ self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
event_id = 0
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index faa7f381a9..92c9058887 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
# Ensure the display name was not updated.
request, channel = self.make_request(
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
new file mode 100644
index 0000000000..bf79086f78
--- /dev/null
+++ b/tests/rest/admin/test_event_reports.py
@@ -0,0 +1,382 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import report_event
+
+from tests import unittest
+
+
+class EventReportsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ report_event.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.room_id1 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok)
+
+ self.room_id2 = self.helper.create_room_as(
+ self.other_user, tok=self.other_user_tok, is_public=True
+ )
+ self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok)
+
+ # Two rooms and two users. Each user sends and reports five events in each room
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.other_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id1, user_tok=self.admin_user_tok,
+ )
+ for i in range(5):
+ self._create_event_and_report(
+ room_id=self.room_id2, user_tok=self.admin_user_tok,
+ )
+
+ self.url = "/_synapse/admin/v1/event_reports"
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.other_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_default_success(self):
+ """
+ Testing list of reported events
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url, access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit(self):
+ """
+ Testing list of reported events with limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertEqual(channel.json_body["next_token"], 5)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_from(self):
+ """
+ Testing list of reported events with a defined starting point (from)
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 15)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_limit_and_from(self):
+ """
+ Testing list of reported events with a defined starting point and limit
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(channel.json_body["next_token"], 15)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self._check_fields(channel.json_body["event_reports"])
+
+ def test_filter_room(self):
+ """
+ Testing list of reported events filtered by room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?room_id=%s" % self.room_id1,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_filter_user(self):
+ """
+ Testing list of reported events filtered by user
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s" % self.other_user,
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 10)
+ self.assertEqual(len(channel.json_body["event_reports"]), 10)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+
+ def test_filter_user_and_room(self):
+ """
+ Testing list of reported events filtered by user and room
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 5)
+ self.assertEqual(len(channel.json_body["event_reports"]), 5)
+ self.assertNotIn("next_token", channel.json_body)
+ self._check_fields(channel.json_body["event_reports"])
+
+ for report in channel.json_body["event_reports"]:
+ self.assertEqual(report["user_id"], self.other_user)
+ self.assertEqual(report["room_id"], self.room_id1)
+
+ def test_valid_search_order(self):
+ """
+ Testing search order. Results should be ordered by timestamp.
+ """
+
+ # fetch the most recent first, largest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ for report in range(1, len(channel.json_body["event_reports"])):
+ self.assertGreaterEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+
+ # fetch the oldest first, smallest timestamp
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ for report in range(1, len(channel.json_body["event_reports"])):
+ self.assertLessEqual(
+ channel.json_body["event_reports"][report - 1]["received_ts"],
+ channel.json_body["event_reports"][report]["received_ts"],
+ )
+
+ def test_invalid_search_order(self):
+ """
+ Testing that an invalid search order returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+ self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+
+ def test_limit_is_negative(self):
+ """
+ Testing that a negative limit parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_from_is_negative(self):
+ """
+ Testing that a negative `from` parameter returns a 400
+ """
+
+ request, channel = self.make_request(
+ "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_next_token(self):
+ """
+ Testing that `next_token` appears at the right place
+ """
+
+ # `next_token` should not appear: the limit equals the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` should not appear: the limit is larger than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 20)
+ self.assertNotIn("next_token", channel.json_body)
+
+ # `next_token` should appear: the limit is smaller than the number of entries
+ request, channel = self.make_request(
+ "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 19)
+ self.assertEqual(channel.json_body["next_token"], 19)
+
+ # Set `from` to the value of `next_token` to request the remaining entries;
+ # `next_token` should no longer appear
+ request, channel = self.make_request(
+ "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["total"], 20)
+ self.assertEqual(len(channel.json_body["event_reports"]), 1)
+ self.assertNotIn("next_token", channel.json_body)
+
+ def _create_event_and_report(self, room_id, user_tok):
+ """Create and report events
+ """
+ resp = self.helper.send(room_id, tok=user_tok)
+ event_id = resp["event_id"]
+
+ request, channel = self.make_request(
+ "POST",
+ "rooms/%s/report/%s" % (room_id, event_id),
+ json.dumps({"score": -100, "reason": "this makes me sad"}),
+ access_token=user_tok,
+ )
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _check_fields(self, content):
+ """Checks that all attributes are present in a event report
+ """
+ for c in content:
+ self.assertIn("id", c)
+ self.assertIn("received_ts", c)
+ self.assertIn("room_id", c)
+ self.assertIn("event_id", c)
+ self.assertIn("user_id", c)
+ self.assertIn("reason", c)
+ self.assertIn("content", c)
+ self.assertIn("sender", c)
+ self.assertIn("room_alias", c)
+ self.assertIn("event_json", c)
+ self.assertIn("score", c["content"])
+ self.assertIn("reason", c["content"])
+ self.assertIn("auth_events", c["event_json"])
+ self.assertIn("type", c["event_json"])
+ self.assertIn("room_id", c["event_json"])
+ self.assertIn("sender", c["event_json"])
+ self.assertIn("content", c["event_json"])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f96011fc1c..98d0623734 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self._is_erased("@user:test", False)
+ d = self.store.mark_user_erased("@user:test")
+ self.assertIsNone(self.get_success(d))
+ self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password).
request, channel = self.make_request(
@@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
+ self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
"""
@@ -996,6 +1001,15 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
+ def _is_erased(self, user_id, expect):
+ """Assert that the user is erased or not
+ """
+ d = self.store.is_user_erased(user_id)
+ if expect:
+ self.assertTrue(self.get_success(d))
+ else:
+ self.assertFalse(self.get_success(d))
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 20636fc400..d4ff55fbff 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
)
return self.get_success(self.db_pool.runWithConnection(_create))
@@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
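+ # also record the id we just allocated (lastval() from foobar_seq) in
+ # stream_positions, as a real writer would be expected to do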
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id, stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
@@ -111,7 +129,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await id_gen.get_next() as stream_id:
+ async with id_gen.get_next() as stream_id:
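+ # get_next() now returns an async context manager, hence `async with`
+ # instead of `with await`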
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
@@ -139,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
ctx3 = self.get_success(id_gen.get_next())
ctx4 = self.get_success(id_gen.get_next())
- s1 = ctx1.__enter__()
- s2 = ctx2.__enter__()
- s3 = ctx3.__enter__()
- s4 = ctx4.__enter__()
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+ s3 = self.get_success(ctx3.__aenter__())
+ s4 = self.get_success(ctx4.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
@@ -152,22 +170,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx2.__exit__(None, None, None)
+ self.get_success(ctx2.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx1.__exit__(None, None, None)
+ self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
- ctx4.__exit__(None, None, None)
+ self.get_success(ctx4.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
- ctx3.__exit__(None, None, None)
+ self.get_success(ctx3.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
@@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_rows("first", 3)
self._insert_rows("second", 4)
- first_id_gen = self._create_id_generator("first")
- second_id_gen = self._create_id_generator("second")
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
@@ -190,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await first_id_gen.get_next() as stream_id:
+ async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
@@ -208,7 +226,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# stream ID
async def _get_next_async():
- with await second_id_gen.get_next() as stream_id:
+ async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
@@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -300,14 +318,18 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
- with self.get_success(id_gen.get_next()) as stream_id:
- self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
@@ -315,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+ def test_restart_during_out_of_order_persistence(self):
+ """Test that restarting a process while another process is writing out
+ of order updates are handled correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Persist two rows at once
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
+
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
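+ # the worker only learns of the new position via replication, which we
+ # simulate here with an explicit advance()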
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+ def test_writer_config_change(self):
+ """Test that changing the writer config correctly works.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ # Initial config has two writers
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # The new config removes one of the writers. Note that if a writer is
+ # removed from the config we assume that it has been shut down and has
+ # finished persisting, hence the persisted-upto position is 5.
+ id_gen_2 = self._create_id_generator("second", writers=["second"])
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+
+ # This config points to a single, previously unused writer.
+ id_gen_3 = self._create_id_generator("third", writers=["third"])
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+ # Check that we get a sane next stream ID with this new config.
+
+ async def _get_next_async():
+ async with id_gen_3.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+
+ self.get_success(_get_next_async())
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
@@ -341,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
positive=False,
)
@@ -364,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
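+ # stream ids are negative for this backwards stream, so the recorded
+ # position is negated (presumably matching the generator's internal
+ # bookkeeping)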
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
@@ -373,16 +480,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
id_gen = self._create_id_generator()
- with self.get_success(id_gen.get_next()) as stream_id:
- self._insert_row("master", stream_id)
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": -1})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
- for stream_id in stream_ids:
- self._insert_row("master", stream_id)
+ async def _get_next_async2():
+ async with id_gen.get_next_mult(3) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen.get_positions(), {"master": -4})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
@@ -399,21 +512,27 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests that having multiple instances that get advanced over
federation works correctly.
"""
- id_gen_1 = self._create_id_generator("first")
- id_gen_2 = self._create_id_generator("second")
+ id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+ id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
- with self.get_success(id_gen_1.get_next()) as stream_id:
- self._insert_row("first", stream_id)
- id_gen_2.advance("first", stream_id)
+ async def _get_next_async():
+ async with id_gen_1.get_next() as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- with self.get_success(id_gen_2.get_next()) as stream_id:
- self._insert_row("second", stream_id)
- id_gen_1.advance("second", stream_id)
+ async def _get_next_async2():
+ async with id_gen_2.get_next() as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
diff --git a/tox.ini b/tox.ini
index ddcab0198f..4d132eff4c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,13 +2,12 @@
envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort
[base]
+extras = test
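+# mock and parameterized are now pulled in via the "test" extra rather
+# than being listed here directly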
deps =
- mock
python-subunit
junitxml
coverage
coverage-enable-subprocess
- parameterized
# cryptography 2.2 requires setuptools >= 18.5
#
@@ -36,7 +35,7 @@ setenv =
[testenv]
deps =
{[base]deps}
-extras = all
+extras = all, test
whitelist_externals =
sh
@@ -84,7 +83,6 @@ deps =
# Old automat version for Twisted
Automat == 0.3.0
- mock
lxml
coverage
coverage-enable-subprocess
@@ -97,7 +95,7 @@ commands =
/bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
# Install Synapse itself. This won't update any libraries.
- pip install -e .
+ pip install -e ".[test]"
{envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
|