diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 11fb05ca96..fc22d89426 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -3,6 +3,10 @@
<!-- Please read CONTRIBUTING.md before submitting your pull request -->
* [ ] Pull request is based on the develop branch
-* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#changelog)
+* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#changelog). The entry should:
+ - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
+ - Use markdown where necessary, mostly for `code blocks`.
+ - End with either a period (.) or an exclamation mark (!).
+ - Start with a capital letter.
* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#sign-off)
* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#code-style))
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c0091346f3..5736ede6c4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -101,8 +101,8 @@ in the format of `PRnumber.type`. The type can be one of the following:
The content of the file is your changelog entry, which should be a short
description of your change in the same style as the rest of our [changelog](
https://github.com/matrix-org/synapse/blob/master/CHANGES.md). The file can
-contain Markdown formatting, and should end with a full stop ('.') for
-consistency.
+contain Markdown formatting, and should end with a full stop (.) or an
+exclamation mark (!) for consistency.
Adding credits to the changelog is encouraged, we value your
contributions and would like to have you shouted out in the release notes!
diff --git a/changelog.d/6655.misc b/changelog.d/6655.misc
new file mode 100644
index 0000000000..01e78bc84e
--- /dev/null
+++ b/changelog.d/6655.misc
@@ -0,0 +1 @@
+Add `local_current_membership` table for tracking local user membership state in rooms.
diff --git a/changelog.d/6663.doc b/changelog.d/6663.doc
new file mode 100644
index 0000000000..83b9c1626a
--- /dev/null
+++ b/changelog.d/6663.doc
@@ -0,0 +1 @@
+Add some helpful tips about changelog entries to the github pull request template.
\ No newline at end of file
diff --git a/changelog.d/6666.misc b/changelog.d/6666.misc
new file mode 100644
index 0000000000..e79c23d2d2
--- /dev/null
+++ b/changelog.d/6666.misc
@@ -0,0 +1 @@
+Port `synapse.replication.tcp` to async/await.
diff --git a/changelog.d/6685.doc b/changelog.d/6685.doc
new file mode 100644
index 0000000000..7cf750fe3f
--- /dev/null
+++ b/changelog.d/6685.doc
@@ -0,0 +1 @@
+Clarify the `account_validity` and `email` sections of the sample configuration.
\ No newline at end of file
diff --git a/changelog.d/6687.misc b/changelog.d/6687.misc
new file mode 100644
index 0000000000..deb0454602
--- /dev/null
+++ b/changelog.d/6687.misc
@@ -0,0 +1 @@
+Allow REST endpoint implementations to raise a RedirectException, which will redirect the user's browser to a given location.
diff --git a/changelog.d/6688.misc b/changelog.d/6688.misc
new file mode 100644
index 0000000000..2a9f28ce5c
--- /dev/null
+++ b/changelog.d/6688.misc
@@ -0,0 +1 @@
+Updates and extensions to the module API.
\ No newline at end of file
diff --git a/changelog.d/6702.misc b/changelog.d/6702.misc
new file mode 100644
index 0000000000..f7bc98409c
--- /dev/null
+++ b/changelog.d/6702.misc
@@ -0,0 +1 @@
+Remove duplicate check for the `session` query parameter on the `/auth/xxx/fallback/web` Client-Server endpoint.
\ No newline at end of file
diff --git a/changelog.d/6706.misc b/changelog.d/6706.misc
new file mode 100644
index 0000000000..1ac11cc04b
--- /dev/null
+++ b/changelog.d/6706.misc
@@ -0,0 +1 @@
+Attempt to retry sending a transaction when we detect a remote server has come back online, rather than waiting for a transaction to be triggered by new data.
diff --git a/changelog.d/6711.bugfix b/changelog.d/6711.bugfix
new file mode 100644
index 0000000000..c70506bd88
--- /dev/null
+++ b/changelog.d/6711.bugfix
@@ -0,0 +1 @@
+Fix `purge_room` admin API.
diff --git a/changelog.d/6712.feature b/changelog.d/6712.feature
new file mode 100644
index 0000000000..2cce0ecf88
--- /dev/null
+++ b/changelog.d/6712.feature
@@ -0,0 +1 @@
+Add org.matrix.e2e_cross_signing to unstable_features in /versions as per [MSC1756](https://github.com/matrix-org/matrix-doc/pull/1756).
diff --git a/changelog.d/6715.misc b/changelog.d/6715.misc
new file mode 100644
index 0000000000..8876b0446d
--- /dev/null
+++ b/changelog.d/6715.misc
@@ -0,0 +1 @@
+Add StateMap type alias to simplify types.
diff --git a/changelog.d/6723.misc b/changelog.d/6723.misc
new file mode 100644
index 0000000000..17f15e73a8
--- /dev/null
+++ b/changelog.d/6723.misc
@@ -0,0 +1 @@
+Updates to the SAML mapping provider API.
diff --git a/changelog.d/6724.misc b/changelog.d/6724.misc
new file mode 100644
index 0000000000..5256be75fa
--- /dev/null
+++ b/changelog.d/6724.misc
@@ -0,0 +1 @@
+When processing a SAML response, log the assertions for easier configuration.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 0a2505e7bb..8e8cf513b0 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -874,23 +874,6 @@ media_store_path: "DATADIR/media_store"
# Optional account validity configuration. This allows for accounts to be denied
# any request after a given period.
#
-# ``enabled`` defines whether the account validity feature is enabled. Defaults
-# to False.
-#
-# ``period`` allows setting the period after which an account is valid
-# after its registration. When renewing the account, its validity period
-# will be extended by this amount of time. This parameter is required when using
-# the account validity feature.
-#
-# ``renew_at`` is the amount of time before an account's expiry date at which
-# Synapse will send an email to the account's email address with a renewal link.
-# This needs the ``email`` and ``public_baseurl`` configuration sections to be
-# filled.
-#
-# ``renew_email_subject`` is the subject of the email sent out with the renewal
-# link. ``%(app)s`` can be used as a placeholder for the ``app_name`` parameter
-# from the ``email`` section.
-#
# Once this feature is enabled, Synapse will look for registered users without an
# expiration date at startup and will add one to every account it found using the
# current settings at that time.
@@ -901,21 +884,55 @@ media_store_path: "DATADIR/media_store"
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10% of the validity period.
#
-#account_validity:
-# enabled: true
-# period: 6w
-# renew_at: 1w
-# renew_email_subject: "Renew your %(app)s account"
-# # Directory in which Synapse will try to find the HTML files to serve to the
-# # user when trying to renew an account. Optional, defaults to
-# # synapse/res/templates.
-# template_dir: "res/templates"
-# # HTML to be displayed to the user after they successfully renewed their
-# # account. Optional.
-# account_renewed_html_path: "account_renewed.html"
-# # HTML to be displayed when the user tries to renew an account with an invalid
-# # renewal token. Optional.
-# invalid_token_html_path: "invalid_token.html"
+account_validity:
+ # The account validity feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # The period after which an account is valid after its registration. When
+ # renewing the account, its validity period will be extended by this amount
+ # of time. This parameter is required when using the account validity
+ # feature.
+ #
+ #period: 6w
+
+ # The amount of time before an account's expiry date at which Synapse will
+ # send an email to the account's email address with a renewal link. By
+ # default, no such emails are sent.
+ #
+ # If you enable this setting, you will also need to fill out the 'email' and
+ # 'public_baseurl' configuration sections.
+ #
+ #renew_at: 1w
+
+ # The subject of the email sent out with the renewal link. '%(app)s' can be
+ # used as a placeholder for the 'app_name' parameter from the 'email'
+ # section.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ # If this is not set, a default value is used.
+ #
+ #renew_email_subject: "Renew your %(app)s account"
+
+ # Directory in which Synapse will try to find templates for the HTML files to
+ # serve to the user when trying to renew an account. If not set, default
+ # templates from within the Synapse package will be used.
+ #
+ #template_dir: "res/templates"
+
+ # File within 'template_dir' giving the HTML to be displayed to the user after
+ # they successfully renewed their account. If not set, default text is used.
+ #
+ #account_renewed_html_path: "account_renewed.html"
+
+ # File within 'template_dir' giving the HTML to be displayed when the user
+ # tries to renew an account with an invalid renewal token. If not set,
+ # default text is used.
+ #
+ #invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#
@@ -1353,107 +1370,110 @@ password_config:
#pepper: "EVEN_MORE_SECRET"
+# Configuration for sending emails from Synapse.
+#
+email:
+ # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'.
+ #
+ #smtp_host: mail.server
+
+ # The port on the mail server for outgoing SMTP. Defaults to 25.
+ #
+ #smtp_port: 587
+
+ # Username/password for authentication to the SMTP server. By default, no
+ # authentication is attempted.
+ #
+ # smtp_user: "exampleusername"
+ # smtp_pass: "examplepassword"
+
+ # Uncomment the following to require TLS transport security for SMTP.
+ # By default, Synapse will connect over plain text, and will then switch to
+ # TLS via STARTTLS *if the SMTP server supports it*. If this option is set,
+ # Synapse will refuse to connect unless the server supports STARTTLS.
+ #
+ #require_transport_security: true
+
+ # Enable sending emails for messages that the user has missed
+ #
+ #enable_notifs: false
+
+ # notif_from defines the "From" address to use when sending emails.
+ # It must be set if email sending is enabled.
+ #
+ # The placeholder '%(app)s' will be replaced by the application name,
+ # which is normally 'app_name' (below), but may be overridden by the
+ # Matrix client application.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
+
+ # app_name defines the default value for '%(app)s' in notif_from. It
+ # defaults to 'Matrix'.
+ #
+ #app_name: my_branded_matrix_server
+
+ # Uncomment the following to disable automatic subscription to email
+ # notifications for new users. Enabled by default.
+ #
+ #notif_for_new_users: false
+
+ # Custom URL for client links within the email notifications. By default
+ # links will be based on "https://matrix.to".
+ #
+ # (This setting used to be called riot_base_url; the old name is still
+ # supported for backwards-compatibility but is now deprecated.)
+ #
+ #client_base_url: "http://localhost/riot"
-# Enable sending emails for password resets, notification events or
-# account expiry notices
-#
-# If your SMTP server requires authentication, the optional smtp_user &
-# smtp_pass variables should be used
-#
-#email:
-# enable_notifs: false
-# smtp_host: "localhost"
-# smtp_port: 25 # SSL: 465, STARTTLS: 587
-# smtp_user: "exampleusername"
-# smtp_pass: "examplepassword"
-# require_transport_security: false
-#
-# # notif_from defines the "From" address to use when sending emails.
-# # It must be set if email sending is enabled.
-# #
-# # The placeholder '%(app)s' will be replaced by the application name,
-# # which is normally 'app_name' (below), but may be overridden by the
-# # Matrix client application.
-# #
-# # Note that the placeholder must be written '%(app)s', including the
-# # trailing 's'.
-# #
-# notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
-#
-# # app_name defines the default value for '%(app)s' in notif_from. It
-# # defaults to 'Matrix'.
-# #
-# #app_name: my_branded_matrix_server
-#
-# # Enable email notifications by default
-# #
-# notif_for_new_users: true
-#
-# # Defining a custom URL for Riot is only needed if email notifications
-# # should contain links to a self-hosted installation of Riot; when set
-# # the "app_name" setting is ignored
-# #
-# riot_base_url: "http://localhost/riot"
-#
-# # Configure the time that a validation email or text message code
-# # will expire after sending
-# #
-# # This is currently used for password resets
-# #
-# #validation_token_lifetime: 1h
-#
-# # Template directory. All template files should be stored within this
-# # directory. If not set, default templates from within the Synapse
-# # package will be used
-# #
-# # For the list of default templates, please see
-# # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
-# #
-# #template_dir: res/templates
-#
-# # Templates for email notifications
-# #
-# notif_template_html: notif_mail.html
-# notif_template_text: notif_mail.txt
-#
-# # Templates for account expiry notices
-# #
-# expiry_template_html: notice_expiry.html
-# expiry_template_text: notice_expiry.txt
-#
-# # Templates for password reset emails sent by the homeserver
-# #
-# #password_reset_template_html: password_reset.html
-# #password_reset_template_text: password_reset.txt
-#
-# # Templates for registration emails sent by the homeserver
-# #
-# #registration_template_html: registration.html
-# #registration_template_text: registration.txt
-#
-# # Templates for validation emails sent by the homeserver when adding an email to
-# # your user account
-# #
-# #add_threepid_template_html: add_threepid.html
-# #add_threepid_template_text: add_threepid.txt
-#
-# # Templates for password reset success and failure pages that a user
-# # will see after attempting to reset their password
-# #
-# #password_reset_template_success_html: password_reset_success.html
-# #password_reset_template_failure_html: password_reset_failure.html
-#
-# # Templates for registration success and failure pages that a user
-# # will see after attempting to register using an email or phone
-# #
-# #registration_template_success_html: registration_success.html
-# #registration_template_failure_html: registration_failure.html
-#
-# # Templates for success and failure pages that a user will see after attempting
-# # to add an email or phone to their account
-# #
-# #add_threepid_success_html: add_threepid_success.html
-# #add_threepid_failure_html: add_threepid_failure.html
+ # Configure the time that a validation email will expire after sending.
+ # Defaults to 1h.
+ #
+ #validation_token_lifetime: 15m
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, default templates from within the Synapse package will be used.
+ #
+ # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+ # If you *do* uncomment it, you will need to make sure that all the templates
+ # below are in the directory.
+ #
+ # Synapse will look for the following templates in this directory:
+ #
+ # * The contents of email notifications of missed events: 'notif_mail.html' and
+ # 'notif_mail.txt'.
+ #
+ # * The contents of account expiry notice emails: 'notice_expiry.html' and
+ # 'notice_expiry.txt'.
+ #
+ # * The contents of password reset emails sent by the homeserver:
+ # 'password_reset.html' and 'password_reset.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in the password reset email: 'password_reset_success.html' and
+ # 'password_reset_failure.html'
+ #
+ # * The contents of address verification emails sent during registration:
+ # 'registration.html' and 'registration.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in an address verification email sent during registration:
+ # 'registration_success.html' and 'registration_failure.html'
+ #
+ # * The contents of address verification emails sent when an address is added
+ # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in an address verification email sent when an address is added
+ # to a Matrix account: 'add_threepid_success.html' and
+ # 'add_threepid_failure.html'
+ #
+ # You can see the default templates at:
+ # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+ #
+ #template_dir: "res/templates"
#password_providers:
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index ba9e874d07..a0b1d563ff 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -209,7 +209,7 @@ Where `<token>` may be either:
* a numeric stream_id to stream updates since (exclusive)
* `NOW` to stream all subsequent updates.
-The `<stream_name>` is the name of a replication stream to subscribe
+The `<stream_name>` is the name of a replication stream to subscribe
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
of streams). It can also be `ALL` to subscribe to all known streams,
in which case the `<token>` must be set to `NOW`.
@@ -234,6 +234,10 @@ in which case the `<token>` must be set to `NOW`.
Used exclusively in tests
+### REMOTE_SERVER_UP (S, C)
+
+ Inform other processes that a remote server may have come back online.
+
See `synapse/replication/tcp/commands.py` for a detailed description and
the format of each command.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index f135c8bc54..5e69104b97 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -470,7 +470,7 @@ class Porter(object):
engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version
)
- prepare_database(db_conn, engine, config=None)
+ prepare_database(db_conn, engine, config=self.hs_config)
store = Store(Database(hs, db_config, engine), db_conn, hs)
db_conn.commit()
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 0dd538d804..abd5297390 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.8.0"
+__version__ = "1.9.0.dev1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index abbc7079a3..2cbfab2569 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,7 +14,6 @@
# limitations under the License.
import logging
-from typing import Dict, Tuple
from six import itervalues
@@ -35,7 +34,7 @@ from synapse.api.errors import (
ResourceLimitError,
)
from synapse.config.server import is_threepid_reserved
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -509,10 +508,7 @@ class Auth(object):
return self.store.is_server_admin(user)
def compute_auth_events(
- self,
- event,
- current_state_ids: Dict[Tuple[str, str], str],
- for_verification: bool = False,
+ self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
"""Given an event and current state return the list of event IDs used
to auth an event.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 9e9844b47c..1c9456e583 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,13 +17,15 @@
"""Contains exceptions and error codes."""
import logging
-from typing import Dict
+from typing import Dict, List
from six import iteritems
from six.moves import http_client
from canonicaljson import json
+from twisted.web import http
+
logger = logging.getLogger(__name__)
@@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
self.msg = msg
+class RedirectException(CodeMessageException):
+ """A pseudo-error indicating that we want to redirect the client to a different
+ location
+
+ Attributes:
+ cookies: a list of set-cookies values to add to the response. For example:
+ b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
+ """
+
+ def __init__(self, location: bytes, http_code: int = http.FOUND):
+ """
+
+ Args:
+ location: the URI to redirect to
+ http_code: the HTTP response code
+ """
+ msg = "Redirect to %s" % (location.decode("utf-8"),)
+ super().__init__(code=http_code, msg=msg)
+ self.location = location
+
+ self.cookies = [] # type: List[bytes]
+
+
class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error
message (as well as an HTTP status code).
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 8e36bc57d3..1c7c6ec0c8 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer):
class AdminCmdReplicationHandler(ReplicationClientHandler):
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
pass
def get_streams_to_replicate(self):
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index e82e0f11e3..2217d4a4fb 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler):
super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+ async def on_rdata(self, stream_name, token, rows):
+ await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 83c436229c..38d11fdd0f 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self)
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(FederationSenderReplicationHandler, self).on_rdata(
+ async def on_rdata(self, stream_name, token, rows):
+ await super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows
)
self.send_handler.process_replication_rows(stream_name, token, rows)
@@ -159,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
args.update(self.send_handler.stream_positions())
return args
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
+ # Let's wake up the transaction queue for the server in case we have
+ # pending stuff to send to it.
+ self.send_handler.wake_destination(server)
+
def start(config_options):
try:
@@ -206,7 +212,7 @@ class FederationSenderHandler(object):
to the federation sender.
"""
- def __init__(self, hs, replication_client):
+ def __init__(self, hs: FederationSenderServer, replication_client):
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
self.federation_sender = hs.get_federation_sender()
@@ -227,6 +233,9 @@ class FederationSenderHandler(object):
self.store.get_room_max_stream_ordering()
)
+ def wake_destination(self, server: str):
+ self.federation_sender.wake_destination(server)
+
def stream_positions(self):
return {"federation": self.federation_position}
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 09e639040a..e46b6ac598 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+ async def on_rdata(self, stream_name, token, rows):
+ await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 03031ee34d..3218da07bd 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+ async def on_rdata(self, stream_name, token, rows):
+ await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 1257098f92..ba536d6f04 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler()
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(UserDirectoryReplicationHandler, self).on_rdata(
+ async def on_rdata(self, stream_name, token, rows):
+ await super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
if stream_name == EventsStream.NAME:
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 35756bed87..74853f9faa 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -37,10 +37,12 @@ class EmailConfig(Config):
self.email_enable_notifs = False
- email_config = config.get("email", {})
+ email_config = config.get("email")
+ if email_config is None:
+ email_config = {}
- self.email_smtp_host = email_config.get("smtp_host", None)
- self.email_smtp_port = email_config.get("smtp_port", None)
+ self.email_smtp_host = email_config.get("smtp_host", "localhost")
+ self.email_smtp_port = email_config.get("smtp_port", 25)
self.email_smtp_user = email_config.get("smtp_user", None)
self.email_smtp_pass = email_config.get("smtp_pass", None)
self.require_transport_security = email_config.get(
@@ -74,9 +76,9 @@ class EmailConfig(Config):
self.email_template_dir = os.path.abspath(template_dir)
self.email_enable_notifs = email_config.get("enable_notifs", False)
- account_validity_renewal_enabled = config.get("account_validity", {}).get(
- "renew_at"
- )
+
+ account_validity_config = config.get("account_validity") or {}
+ account_validity_renewal_enabled = account_validity_config.get("renew_at")
self.threepid_behaviour_email = (
# Have Synapse handle the email sending if account_threepid_delegates.email
@@ -278,7 +280,9 @@ class EmailConfig(Config):
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
- self.email_riot_base_url = email_config.get("riot_base_url", None)
+ self.email_riot_base_url = email_config.get(
+ "client_base_url", email_config.get("riot_base_url", None)
+ )
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
@@ -294,107 +298,111 @@ class EmailConfig(Config):
raise ConfigError("Unable to find email template file %s" % (p,))
def generate_config_section(self, config_dir_path, server_name, **kwargs):
- return """
- # Enable sending emails for password resets, notification events or
- # account expiry notices
- #
- # If your SMTP server requires authentication, the optional smtp_user &
- # smtp_pass variables should be used
- #
- #email:
- # enable_notifs: false
- # smtp_host: "localhost"
- # smtp_port: 25 # SSL: 465, STARTTLS: 587
- # smtp_user: "exampleusername"
- # smtp_pass: "examplepassword"
- # require_transport_security: false
- #
- # # notif_from defines the "From" address to use when sending emails.
- # # It must be set if email sending is enabled.
- # #
- # # The placeholder '%(app)s' will be replaced by the application name,
- # # which is normally 'app_name' (below), but may be overridden by the
- # # Matrix client application.
- # #
- # # Note that the placeholder must be written '%(app)s', including the
- # # trailing 's'.
- # #
- # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
- #
- # # app_name defines the default value for '%(app)s' in notif_from. It
- # # defaults to 'Matrix'.
- # #
- # #app_name: my_branded_matrix_server
- #
- # # Enable email notifications by default
- # #
- # notif_for_new_users: true
- #
- # # Defining a custom URL for Riot is only needed if email notifications
- # # should contain links to a self-hosted installation of Riot; when set
- # # the "app_name" setting is ignored
- # #
- # riot_base_url: "http://localhost/riot"
- #
- # # Configure the time that a validation email or text message code
- # # will expire after sending
- # #
- # # This is currently used for password resets
- # #
- # #validation_token_lifetime: 1h
- #
- # # Template directory. All template files should be stored within this
- # # directory. If not set, default templates from within the Synapse
- # # package will be used
- # #
- # # For the list of default templates, please see
- # # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
- # #
- # #template_dir: res/templates
- #
- # # Templates for email notifications
- # #
- # notif_template_html: notif_mail.html
- # notif_template_text: notif_mail.txt
- #
- # # Templates for account expiry notices
- # #
- # expiry_template_html: notice_expiry.html
- # expiry_template_text: notice_expiry.txt
- #
- # # Templates for password reset emails sent by the homeserver
- # #
- # #password_reset_template_html: password_reset.html
- # #password_reset_template_text: password_reset.txt
- #
- # # Templates for registration emails sent by the homeserver
- # #
- # #registration_template_html: registration.html
- # #registration_template_text: registration.txt
- #
- # # Templates for validation emails sent by the homeserver when adding an email to
- # # your user account
- # #
- # #add_threepid_template_html: add_threepid.html
- # #add_threepid_template_text: add_threepid.txt
- #
- # # Templates for password reset success and failure pages that a user
- # # will see after attempting to reset their password
- # #
- # #password_reset_template_success_html: password_reset_success.html
- # #password_reset_template_failure_html: password_reset_failure.html
- #
- # # Templates for registration success and failure pages that a user
- # # will see after attempting to register using an email or phone
- # #
- # #registration_template_success_html: registration_success.html
- # #registration_template_failure_html: registration_failure.html
+ return """\
+ # Configuration for sending emails from Synapse.
#
- # # Templates for success and failure pages that a user will see after attempting
- # # to add an email or phone to their account
- # #
- # #add_threepid_success_html: add_threepid_success.html
- # #add_threepid_failure_html: add_threepid_failure.html
+ email:
+ # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'.
+ #
+ #smtp_host: mail.server
+
+ # The port on the mail server for outgoing SMTP. Defaults to 25.
+ #
+ #smtp_port: 587
+
+ # Username/password for authentication to the SMTP server. By default, no
+ # authentication is attempted.
+ #
+ # smtp_user: "exampleusername"
+ # smtp_pass: "examplepassword"
+
+ # Uncomment the following to require TLS transport security for SMTP.
+ # By default, Synapse will connect over plain text, and will then switch to
+ # TLS via STARTTLS *if the SMTP server supports it*. If this option is set,
+ # Synapse will refuse to connect unless the server supports STARTTLS.
+ #
+ #require_transport_security: true
+
+ # Enable sending emails for messages that the user has missed
+ #
+ #enable_notifs: false
+
+ # notif_from defines the "From" address to use when sending emails.
+ # It must be set if email sending is enabled.
+ #
+ # The placeholder '%(app)s' will be replaced by the application name,
+ # which is normally 'app_name' (below), but may be overridden by the
+ # Matrix client application.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
+
+ # app_name defines the default value for '%(app)s' in notif_from. It
+ # defaults to 'Matrix'.
+ #
+ #app_name: my_branded_matrix_server
+
+ # Uncomment the following to disable automatic subscription to email
+ # notifications for new users. Enabled by default.
+ #
+ #notif_for_new_users: false
+
+ # Custom URL for client links within the email notifications. By default
+ # links will be based on "https://matrix.to".
+ #
+ # (This setting used to be called riot_base_url; the old name is still
+ # supported for backwards-compatibility but is now deprecated.)
+ #
+ #client_base_url: "http://localhost/riot"
+
+ # Configure the time that a validation email will expire after sending.
+ # Defaults to 1h.
+ #
+ #validation_token_lifetime: 15m
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, default templates from within the Synapse package will be used.
+ #
+ # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+ # If you *do* uncomment it, you will need to make sure that all the templates
+ # below are in the directory.
+ #
+ # Synapse will look for the following templates in this directory:
+ #
+ # * The contents of email notifications of missed events: 'notif_mail.html' and
+ # 'notif_mail.txt'.
+ #
+ # * The contents of account expiry notice emails: 'notice_expiry.html' and
+ # 'notice_expiry.txt'.
+ #
+ # * The contents of password reset emails sent by the homeserver:
+ # 'password_reset.html' and 'password_reset.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in the password reset email: 'password_reset_success.html' and
+ # 'password_reset_failure.html'
+ #
+ # * The contents of address verification emails sent during registration:
+ # 'registration.html' and 'registration.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in an address verification email sent during registration:
+ # 'registration_success.html' and 'registration_failure.html'
+ #
+ # * The contents of address verification emails sent when an address is added
+ # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in an address verification email sent when an address is added
+ # to a Matrix account: 'add_threepid_success.html' and
+ # 'add_threepid_failure.html'
+ #
+ # You can see the default templates at:
+ # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+ #
+ #template_dir: "res/templates"
"""
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 0910958649..6f2b3a7faa 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -35,7 +35,7 @@ class PushConfig(Config):
# Now check for the one in the 'email' section and honour it,
# with a warning.
- push_config = config.get("email", {})
+ push_config = config.get("email") or {}
redact_content = push_config.get("redact_content")
if redact_content is not None:
print(
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ee9614c5f7..b873995a49 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -27,6 +27,8 @@ class AccountValidityConfig(Config):
section = "accountvalidity"
def __init__(self, config, synapse_config):
+ if config is None:
+ return
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@@ -159,23 +161,6 @@ class RegistrationConfig(Config):
# Optional account validity configuration. This allows for accounts to be denied
# any request after a given period.
#
- # ``enabled`` defines whether the account validity feature is enabled. Defaults
- # to False.
- #
- # ``period`` allows setting the period after which an account is valid
- # after its registration. When renewing the account, its validity period
- # will be extended by this amount of time. This parameter is required when using
- # the account validity feature.
- #
- # ``renew_at`` is the amount of time before an account's expiry date at which
- # Synapse will send an email to the account's email address with a renewal link.
- # This needs the ``email`` and ``public_baseurl`` configuration sections to be
- # filled.
- #
- # ``renew_email_subject`` is the subject of the email sent out with the renewal
- # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
- # from the ``email`` section.
- #
# Once this feature is enabled, Synapse will look for registered users without an
# expiration date at startup and will add one to every account it found using the
# current settings at that time.
@@ -186,21 +171,55 @@ class RegistrationConfig(Config):
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10%% of the validity period.
#
- #account_validity:
- # enabled: true
- # period: 6w
- # renew_at: 1w
- # renew_email_subject: "Renew your %%(app)s account"
- # # Directory in which Synapse will try to find the HTML files to serve to the
- # # user when trying to renew an account. Optional, defaults to
- # # synapse/res/templates.
- # template_dir: "res/templates"
- # # HTML to be displayed to the user after they successfully renewed their
- # # account. Optional.
- # account_renewed_html_path: "account_renewed.html"
- # # HTML to be displayed when the user tries to renew an account with an invalid
- # # renewal token. Optional.
- # invalid_token_html_path: "invalid_token.html"
+ account_validity:
+ # The account validity feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # The period after which an account is valid after its registration. When
+ # renewing the account, its validity period will be extended by this amount
+ # of time. This parameter is required when using the account validity
+ # feature.
+ #
+ #period: 6w
+
+ # The amount of time before an account's expiry date at which Synapse will
+ # send an email to the account's email address with a renewal link. By
+ # default, no such emails are sent.
+ #
+ # If you enable this setting, you will also need to fill out the 'email' and
+ # 'public_baseurl' configuration sections.
+ #
+ #renew_at: 1w
+
+ # The subject of the email sent out with the renewal link. '%%(app)s' can be
+ # used as a placeholder for the 'app_name' parameter from the 'email'
+ # section.
+ #
+ # Note that the placeholder must be written '%%(app)s', including the
+ # trailing 's'.
+ #
+ # If this is not set, a default value is used.
+ #
+ #renew_email_subject: "Renew your %%(app)s account"
+
+ # Directory in which Synapse will try to find templates for the HTML files to
+ # serve to the user when trying to renew an account. If not set, default
+ # templates from within the Synapse package will be used.
+ #
+ #template_dir: "res/templates"
+
+ # File within 'template_dir' giving the HTML to be displayed to the user after
+ # they successfully renewed their account. If not set, default text is used.
+ #
+ #account_renewed_html_path: "account_renewed.html"
+
+ # File within 'template_dir' giving the HTML to be displayed when the user
+ # tries to renew an account with an invalid renewal token. If not set,
+ # default text is used.
+ #
+ #invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index b91414aa35..423c158b11 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -121,6 +121,7 @@ class SAML2Config(Config):
required_methods = [
"get_saml_attributes",
"saml_response_to_user_attributes",
+ "get_remote_user_id",
]
missing_methods = [
method
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a44baea365..9ea85e93e6 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Union
from six import iteritems
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import StateMap
@attr.s(slots=True)
@@ -106,13 +107,11 @@ class EventContext:
_state_group = attr.ib(default=None, type=Optional[int])
state_group_before_event = attr.ib(default=None, type=Optional[int])
prev_group = attr.ib(default=None, type=Optional[int])
- delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+ delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
app_service = attr.ib(default=None, type=Optional[ApplicationService])
- _current_state_ids = attr.ib(
- default=None, type=Optional[Dict[Tuple[str, str], str]]
- )
- _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+ _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+ _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
@staticmethod
def with_state(
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index ced4925a98..174f6e42be 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object):
def federation_ack(self, token):
self._clear_queue_before_pos(token)
- def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+ async def get_replication_rows(
+ self, from_token, to_token, limit, federation_ack=None
+ ):
"""Get rows to be sent over federation between the two tokens
Args:
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4ebb0e8bc0..36c83c3027 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -21,6 +21,7 @@ from prometheus_client import Counter
from twisted.internet import defer
+import synapse
import synapse.metrics
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
@@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter(
class FederationSender(object):
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self.server_name = hs.hostname
@@ -482,7 +483,20 @@ class FederationSender(object):
def send_device_messages(self, destination):
if destination == self.server_name:
- logger.info("Not sending device update to ourselves")
+ logger.warning("Not sending device update to ourselves")
+ return
+
+ self._get_per_destination_queue(destination).attempt_new_transaction()
+
+ def wake_destination(self, destination: str):
+ """Called when we want to retry sending transactions to a remote.
+
+ This is mainly useful if the remote server has been down and we think it
+ might have come back.
+ """
+
+ if destination == self.server_name:
+ logger.warning("Not waking up ourselves")
return
self._get_per_destination_queue(destination).attempt_new_transaction()
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index a5b36b1827..5012aaea35 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
+from synapse.types import StateMap
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver.
@@ -77,7 +78,7 @@ class PerDestinationQueue(object):
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
- self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
+ self._pending_edus_keyed = {} # type: StateMap[Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b4cbf23394..d8cf9ed299 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -44,6 +44,7 @@ from synapse.logging.opentracing import (
tags,
whitelisted_homeserver,
)
+from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object):
- def __init__(self, hs):
+ def __init__(self, hs: HomeServer):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.notifer = hs.get_notifier()
+
+ self.replication_client = None
+ if hs.config.worker.worker_app:
+ self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets
async def authenticate_request(self, request, content):
@@ -166,6 +172,17 @@ class Authenticator(object):
try:
logger.info("Marking origin %r as up", origin)
await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+ # Inform the relevant places that the remote server is back up.
+ self.notifer.notify_remote_server_up(origin)
+ if self.replication_client:
+ # If we're on a worker we try and inform master about this. The
+ # replication client doesn't hook into the notifier to avoid
+ # infinite loops where we send a `REMOTE_SERVER_UP` command to
+ # master, which then echoes it back to us which in turn pokes
+ # the notifier.
+ self.replication_client.send_remote_server_up(origin)
+
except Exception:
logger.exception("Error resetting retry timings on %s", origin)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 76d18a8ba8..60a7c938bc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,9 +14,11 @@
# limitations under the License.
import logging
+from typing import List
from synapse.api.constants import Membership
-from synapse.types import RoomStreamToken
+from synapse.events import FrozenEvent
+from synapse.types import RoomStreamToken, StateMap
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -134,7 +136,7 @@ class AdminHandler(BaseHandler):
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
- rooms = await self.store.get_rooms_for_user_where_membership_is(
+ rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
@@ -259,35 +261,26 @@ class ExfiltrationWriter(object):
"""Interface used to specify how to write exported data.
"""
- def write_events(self, room_id, events):
+ def write_events(self, room_id: str, events: List[FrozenEvent]):
"""Write a batch of events for a room.
-
- Args:
- room_id (str)
- events (list[FrozenEvent])
"""
pass
- def write_state(self, room_id, event_id, state):
+ def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
-
- Args:
- room_id (str)
- event_id (str)
- state (dict[tuple[str, str], FrozenEvent])
"""
pass
- def write_invite(self, room_id, event, state):
+ def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
"""Write an invite for the room, with associated invite state.
Args:
- room_id (str)
- event (FrozenEvent)
- state (dict[tuple[str, str], dict]): A subset of the state at the
+ room_id
+ event
+ state: A subset of the state at the
invite, with a subset of the event keys (type, state_key
content and sender)
"""
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4426967f88..2afb390a92 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
- pending_invites = await self.store.get_invited_rooms_for_user(user_id)
+ pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
for room in pending_invites:
try:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 61b6713c88..d4f9a792fc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -89,7 +89,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
- auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
def shortstr(iterable, maxitems=5):
@@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(
- ours.values()
- ) # type: list[dict[tuple[str, str], str]]
+ state_maps = list(ours.values()) # type: list[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+ auth_events: Optional[StateMap[EventBase]],
backfilled: bool,
):
"""
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 44ec3e66ae..2e6755f19c 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
- room_list = await self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9cab2adbfb..9f50196ea7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ Requester,
+ RoomAlias,
+ RoomID,
+ RoomStreamToken,
+ StateMap,
+ StreamToken,
+ UserID,
+)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
@@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _update_upgraded_room_pls(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self,
+ requester: Requester,
+ old_room_id: str,
+ new_room_id: str,
+ old_room_state: StateMap[str],
):
"""Send updated power levels in both rooms after an upgrade
Args:
- requester (synapse.types.Requester): the user requesting the upgrade
- old_room_id (str): the id of the room to be replaced
- new_room_id (str): the id of the replacement room
- old_room_state (dict[tuple[str, str], str]): the state map for the old room
+ requester: the user requesting the upgrade
+ old_room_id: the id of the room to be replaced
+ new_room_id: the id of the replacement room
+ old_room_state: the state map for the old room
Returns:
Deferred
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 03bb52ccfb..15e8aa5249 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -690,7 +690,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
- invite = yield self.store.get_invite_for_user_in_room(
+ invite = yield self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 107f97032b..7f411b53b9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -32,6 +32,7 @@ from synapse.types import (
mxid_localpart_allowed_characters,
)
from synapse.util.async_helpers import Linearizer
+from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__)
@@ -132,17 +133,28 @@ class SamlHandler:
logger.warning("SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed")
- logger.info("SAML2 response: %s", saml2_auth.origxml)
- logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+ logger.debug("SAML2 response: %s", saml2_auth.origxml)
+ for assertion in saml2_auth.assertions:
+ # kibana limits the length of a log field, whereas this is all rather
+ # useful, so split it up.
+ count = 0
+ for part in chunk_seq(str(assertion), 10000):
+ logger.info(
+ "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
+ )
+ count += 1
- try:
- remote_user_id = saml2_auth.ava["uid"][0]
- except KeyError:
- logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "'uid' not in SAML2 response")
+ logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise Exception("Failed to extract remote user id from SAML response")
+
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -279,6 +291,20 @@ class DefaultSamlMappingProvider(object):
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
+ self._grandfathered_mxid_source_attribute = (
+ module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+ )
+
+ def get_remote_user_id(
+ self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+ ):
+ """Extracts the remote user id from the SAML response"""
+ try:
+ return saml_response.ava["uid"][0]
+ except KeyError:
+ logger.warning("SAML2 response lacks a 'uid' attestation")
+ raise SynapseError(400, "'uid' not in SAML2 response")
+
def saml_response_to_user_attributes(
self,
saml_response: saml2.response.AuthnResponse,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ef750d1497..110097eab9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2d3b8ba73c..cd95f85e3f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1662,7 +1662,7 @@ class SyncHandler(object):
Membership.BAN,
)
- room_list = await self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=membership_list
)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index b635c339ed..d5ca9cb07b 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -257,7 +257,7 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
- def get_all_typing_updates(self, last_id, current_id):
+ async def get_all_typing_updates(self, last_id, current_id):
if last_id == current_id:
return []
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 943d12c907..04bc2385a2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import cgi
import collections
+import html
import http.client
import logging
import types
@@ -36,6 +36,7 @@ import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
+ RedirectException,
SynapseError,
UnrecognizedRequestError,
)
@@ -153,14 +154,18 @@ def _return_html_error(f, request):
Args:
f (twisted.python.failure.Failure):
- request (twisted.web.iweb.IRequest):
+ request (twisted.web.server.Request):
"""
if f.check(CodeMessageException):
cme = f.value
code = cme.code
msg = cme.msg
- if isinstance(cme, SynapseError):
+ if isinstance(cme, RedirectException):
+ logger.info("%s redirect to %s", request, cme.location)
+ request.setHeader(b"location", cme.location)
+ request.cookies.extend(cme.cookies)
+ elif isinstance(cme, SynapseError):
logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
@@ -178,7 +183,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
+ body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 305b9b0178..d680ee95e1 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,18 +17,26 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import UserID
+"""
+This package defines the 'stable' API which can be used by extension modules which
+are loaded into Synapse.
+"""
+
+__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
+
logger = logging.getLogger(__name__)
class ModuleApi(object):
- """A proxy object that gets passed to password auth providers so they
+ """A proxy object that gets passed to various plugin modules so they
can register new users etc if necessary.
"""
def __init__(self, hs, auth_handler):
- self.hs = hs
+ self._hs = hs
self._store = hs.get_datastore()
self._auth = hs.get_auth()
@@ -64,7 +73,7 @@ class ModuleApi(object):
"""
if username.startswith("@"):
return username
- return UserID(username, self.hs.hostname).to_string()
+ return UserID(username, self._hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
@@ -111,10 +120,14 @@ class ModuleApi(object):
displayname (str|None): The displayname of the new user.
emails (List[str]): Emails to bind to the new user.
+ Raises:
+ SynapseError if there is an error performing the registration. Check the
+ 'errcode' property for more information on the reason for failure
+
Returns:
Deferred[str]: user_id
"""
- return self.hs.get_registration_handler().register_user(
+ return self._hs.get_registration_handler().register_user(
localpart=localpart, default_display_name=displayname, bind_emails=emails
)
@@ -131,12 +144,34 @@ class ModuleApi(object):
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
- return self.hs.get_registration_handler().register_device(
+ return self._hs.get_registration_handler().register_device(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
)
+ def record_user_external_id(
+ self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
+ ) -> defer.Deferred:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ return self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+
+ def generate_short_term_login_token(
+ self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ ) -> str:
+ """Generate a login token suitable for m.login.token authentication"""
+ return self._hs.get_macaroon_generator().generate_short_term_login_token(
+ user_id, duration_in_ms
+ )
+
@defer.inlineCallbacks
def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user
@@ -157,7 +192,7 @@ class ModuleApi(object):
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
- yield self.hs.get_device_handler().delete_device(user_id, device_id)
+ yield self._hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
new file mode 100644
index 0000000000..b15441772c
--- /dev/null
+++ b/synapse/module_api/errors.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Exception types which are exposed as part of the stable module API"""
+
+from synapse.api.errors import RedirectException, SynapseError # noqa: F401
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 5f5f765bea..6132727cbd 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,11 +15,13 @@
import logging
from collections import namedtuple
+from typing import Callable, List
from prometheus_client import Counter
from twisted.internet import defer
+import synapse.server
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
@@ -154,7 +156,7 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
@@ -164,7 +166,12 @@ class Notifier(object):
self.store = hs.get_datastore()
self.pending_new_room_events = []
- self.replication_callbacks = []
+ # Called when there are new things to stream over replication
+ self.replication_callbacks = [] # type: List[Callable[[], None]]
+
+ # Called when remote servers have come back online after having been
+ # down.
+ self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]]
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
@@ -205,7 +212,7 @@ class Notifier(object):
"synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
)
- def add_replication_callback(self, cb):
+ def add_replication_callback(self, cb: Callable[[], None]):
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
@@ -213,6 +220,12 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
+ def add_remote_server_up_callback(self, cb: Callable[[str], None]):
+ """Add a callback that will be called when synapse detects a server
+ has been
+ """
+ self.remote_server_up_callbacks.append(cb)
+
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
):
@@ -522,3 +535,15 @@ class Notifier(object):
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
+
+ def notify_remote_server_up(self, server: str):
+ """Notify any replication that a remote server has come back up
+ """
+ # We call federation_sender directly rather than registering as a
+ # callback as a) we already have a reference to it and b) it introduces
+ # circular dependencies.
+ if self.federation_sender:
+ self.federation_sender.wake_destination(server)
+
+ for cb in self.remote_server_up_callbacks:
+ cb(server)
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index de5c101a58..5dae4648c0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -21,7 +21,7 @@ from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
- invites = yield store.get_invited_rooms_for_user(user_id)
+ invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29f35b9915..3aa6cb8b96 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -152,7 +152,7 @@ class SlavedEventStore(
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
- self.get_invited_rooms_for_user.invalidate((state_key,))
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index aa7fd90e26..fc06a7b053 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- return self.store.process_replication_rows(stream_name, token, rows)
+ self.store.process_replication_rows(stream_name, token, rows)
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
- return self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -146,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
if d:
d.callback(data)
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index cbb36b9acf..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -387,6 +387,20 @@ class UserIpCommand(Command):
)
+class RemoteServerUpCommand(Command):
+ """Sent when a worker has detected that a remote server is no longer
+ "down" and retry timings should be reset.
+
+ If sent from a client the server will relay to all other workers.
+
+ Format::
+
+ REMOTE_SERVER_UP <server>
+ """
+
+ NAME = "REMOTE_SERVER_UP"
+
+
_COMMANDS = (
ServerCommand,
RdataCommand,
@@ -401,6 +415,7 @@ _COMMANDS = (
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
+ RemoteServerUpCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
SyncCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
# The commands the client is allowed to send
@@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index db0353c996..131e5acb09 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -76,17 +76,17 @@ from synapse.replication.tcp.commands import (
PingCommand,
PositionCommand,
RdataCommand,
+ RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
+from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
-from .streams import STREAMS_MAP
-
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -241,19 +241,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
- def handle_command(self, cmd):
+ async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>
+ By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
- cmd (synapse.replication.tcp.commands.Command): received command
-
- Returns:
- Deferred
+ cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
- return handler(cmd)
+ await handler(cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -326,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- def on_PING(self, line):
+ async def on_PING(self, line):
self.received_ping = True
- def on_ERROR(self, cmd):
+ async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -429,16 +426,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
- def on_NAME(self, cmd):
+ async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- def on_USER_SYNC(self, cmd):
- return self.streamer.on_user_sync(
+ async def on_USER_SYNC(self, cmd):
+ await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- def on_REPLICATE(self, cmd):
+ async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
@@ -449,23 +446,26 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
for stream in iterkeys(self.streamer.streams_by_name)
]
- return make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
- return self.subscribe_to_stream(stream_name, token)
+ await self.subscribe_to_stream(stream_name, token)
- def on_FEDERATION_ACK(self, cmd):
- return self.streamer.federation_ack(cmd.token)
+ async def on_FEDERATION_ACK(self, cmd):
+ self.streamer.federation_ack(cmd.token)
- def on_REMOVE_PUSHER(self, cmd):
- return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ async def on_REMOVE_PUSHER(self, cmd):
+ await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
- def on_INVALIDATE_CACHE(self, cmd):
- return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ async def on_INVALIDATE_CACHE(self, cmd):
+ self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
- def on_USER_IP(self, cmd):
- return self.streamer.on_user_ip(
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.streamer.on_remote_server_up(cmd.data)
+
+ async def on_USER_IP(self, cmd):
+ self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
@@ -474,8 +474,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
- @defer.inlineCallbacks
- def subscribe_to_stream(self, stream_name, token):
+ async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
@@ -487,7 +486,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
- updates, current_token = yield self.streamer.get_stream_updates(
+ updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
@@ -560,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def send_sync(self, data):
self.send_command(SyncCommand(data))
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
@@ -572,7 +574,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
@@ -580,14 +582,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
raise NotImplementedError()
@abc.abstractmethod
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@@ -597,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
@@ -676,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting:
self.handler.finished_connecting()
- def on_SERVER(self, cmd):
+ async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- def on_RDATA(self, cmd):
+ async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -701,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- return self.handler.on_rdata(stream_name, cmd.token, rows)
+ await self.handler.on_rdata(stream_name, cmd.token, rows)
- def on_POSITION(self, cmd):
+ async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- return self.handler.on_position(cmd.stream_name, cmd.token)
+ await self.handler.on_position(cmd.stream_name, cmd.token)
+
+ async def on_SYNC(self, cmd):
+ self.handler.on_sync(cmd.data)
- def on_SYNC(self, cmd):
- return self.handler.on_sync(cmd.data)
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index cbfdaf5773..6ebf944f66 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,6 @@ from six import itervalues
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
@@ -121,6 +120,7 @@ class ReplicationStreamer(object):
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -155,8 +155,7 @@ class ReplicationStreamer(object):
run_as_background_process("replication_notifier", self._run_notifier_loop)
- @defer.inlineCallbacks
- def _run_notifier_loop(self):
+ async def _run_notifier_loop(self):
self.is_looping = True
try:
@@ -185,7 +184,7 @@ class ReplicationStreamer(object):
continue
if self._replication_torture_level:
- yield self.clock.sleep(
+ await self.clock.sleep(
self._replication_torture_level / 1000.0
)
@@ -196,7 +195,7 @@ class ReplicationStreamer(object):
stream.upto_token,
)
try:
- updates, current_token = yield stream.get_updates()
+ updates, current_token = await stream.get_updates()
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
@@ -233,7 +232,7 @@ class ReplicationStreamer(object):
self.is_looping = False
@measure_func("repl.get_stream_updates")
- def get_stream_updates(self, stream_name, token):
+ async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -241,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return stream.get_updates_since(token)
+ return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
@@ -252,22 +251,20 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- @defer.inlineCallbacks
- def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- yield self.presence_handler.update_external_syncs_row(
+ await self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
- @defer.inlineCallbacks
- def on_remove_pusher(self, app_id, push_key, user_id):
+ async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
- yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
@@ -281,15 +278,24 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip")
- @defer.inlineCallbacks
- def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ async def on_user_ip(
+ self, user_id, access_token, ip, user_agent, device_id, last_seen
+ ):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- yield self.store.insert_client_ip(
+ await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen
)
- yield self._server_notices_sender.on_user_ip(user_id)
+ await self._server_notices_sender.on_user_ip(user_id)
+
+ @measure_func("repl.on_remote_server_up")
+ def on_remote_server_up(self, server: str):
+ self.notifier.notify_remote_server_up(server)
+
+ def send_remote_server_up(self, server: str):
+ for conn in self.connections:
+ conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4ab0334fc1..e03e77199b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -19,8 +19,6 @@ import logging
from collections import namedtuple
from typing import Any
-from twisted.internet import defer
-
logger = logging.getLogger(__name__)
@@ -144,8 +142,7 @@ class Stream(object):
self.upto_token = self.current_token()
self.last_token = self.upto_token
- @defer.inlineCallbacks
- def get_updates(self):
+ async def get_updates(self):
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
@@ -156,13 +153,12 @@ class Stream(object):
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
"""
- updates, current_token = yield self.get_updates_since(self.last_token)
+ updates, current_token = await self.get_updates_since(self.last_token)
self.last_token = current_token
return updates, current_token
- @defer.inlineCallbacks
- def get_updates_since(self, from_token):
+ async def get_updates_since(self, from_token):
"""Like get_updates except allows specifying from when we should
stream updates
@@ -182,15 +178,16 @@ class Stream(object):
if from_token == current_token:
return [], current_token
+ logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
- rows = yield self.update_function(
+ rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
- rows = yield self.update_function(from_token, current_token)
+ rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
@@ -295,9 +292,8 @@ class PushRulesStream(Stream):
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+ async def update_function(self, from_token, to_token, limit):
+ rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
@@ -413,9 +409,8 @@ class AccountDataStream(Stream):
super(AccountDataStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- global_results, room_results = yield self.store.get_all_updated_account_data(
+ async def update_function(self, from_token, to_token, limit):
+ global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 0843e5aa90..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,8 +19,6 @@ from typing import Tuple, Type
import attr
-from twisted.internet import defer
-
from ._base import Stream
@@ -122,16 +120,15 @@ class EventsStream(Stream):
super(EventsStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, current_token, limit=None):
- event_rows = yield self._store.get_all_new_forward_event_rows(
+ async def update_function(self, from_token, current_token, limit=None):
+ event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
event_updates = (
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
- state_rows = yield self._store.get_all_updated_current_state_deltas(
+ state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit
)
state_updates = (
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 7a256b6ecb..50e080673b 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -206,10 +206,6 @@ class AuthRestServlet(RestServlet):
return None
elif stagetype == LoginType.TERMS:
- if ("session" not in request.args or len(request.args["session"])) == 0:
- raise SynapseError(400, "No session supplied")
-
- session = request.args["session"][0]
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 2a477ad22e..3d0fefb4df 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -71,6 +71,8 @@ class VersionsRestServlet(RestServlet):
# Implements support for label-based filtering as described in
# MSC2326.
"org.matrix.label_based_filtering": True,
+ # Implements support for cross signing as described in MSC1756
+ "org.matrix.e2e_cross_signing": True,
},
},
)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index b5e0b57095..0731403047 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,5 @@
+import twisted.internet
+
import synapse.api.auth
import synapse.config.homeserver
import synapse.federation.sender
@@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.message
+import synapse.handlers.presence
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
+import synapse.notifier
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -85,3 +89,11 @@ class HomeServer(object):
self,
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
pass
+ def get_notifier(self) -> synapse.notifier.Notifier:
+ pass
+ def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
+ pass
+ def get_clock(self) -> synapse.util.Clock:
+ pass
+ def get_reactor(self) -> twisted.internet.base.ReactorBase:
+ pass
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 2dac90578c..f7432c8d2f 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -105,7 +105,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
- rooms = yield self._store.get_rooms_for_user_where_membership_is(
+ rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
system_mxid = self._config.server_notices_mxid
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 5accc071ab..cacd0c0c2b 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional
from six import iteritems, itervalues
@@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
room_id: str,
room_version: str,
- state_sets: List[Dict[Tuple[str, str], str]],
+ state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index b2f9865f39..d6c34ce3b7 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,7 +15,7 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional
from six import iteritems, iterkeys, itervalues
@@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks
def resolve_events_with_store(
room_id: str,
- state_sets: List[Dict[Tuple[str, str], str]],
+ state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable,
):
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 72fb8a6317..6216fdd204 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,7 @@
import heapq
import itertools
import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
from six import iteritems, itervalues
@@ -27,6 +27,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
def resolve_events_with_store(
room_id: str,
room_version: str,
- state_sets: List[Dict[Tuple[str, str], str]],
+ state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
@@ -393,12 +394,12 @@ def _iterative_auth_checks(
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
- base_state (dict[tuple[str, str], str]): The set of state to start with
+ base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
- Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+ Deferred[StateMap[str]]: Returns the final updated state
"""
resolved_state = base_state.copy()
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 54ed8574c4..bf91512daf 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
-from synapse.util import batch_iter
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 9a828231c4..f0a7962dd0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -33,13 +33,13 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.types import get_verify_key_from_cross_signing_key
-from synapse.util import batch_iter
from synapse.util.caches.descriptors import (
Cache,
cached,
cachedInlineCallbacks,
cachedList,
)
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 58f35d7f56..bb69c20448 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -43,9 +43,9 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.storage.database import Database
from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -128,6 +128,7 @@ class EventsStore(
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def _read_forward_extremities(self):
@@ -547,6 +548,34 @@ class EventsStore(
],
)
+ # Note: Do we really want to delete rows here (that we do not
+ # subsequently reinsert below)? While technically correct it means
+ # we have no record of the fact the user *was* a member of the
+ # room but got, say, state reset out of it.
+ if to_delete or to_insert:
+ txn.executemany(
+ "DELETE FROM local_current_membership"
+ " WHERE room_id = ? AND user_id = ?",
+ (
+ (room_id, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ if etype == EventTypes.Member and self.is_mine_id(state_key)
+ ),
+ )
+
+ if to_insert:
+ txn.executemany(
+ """INSERT INTO local_current_membership
+ (room_id, user_id, event_id, membership)
+ VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[1], ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+ ],
+ )
+
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id,
@@ -1724,6 +1753,7 @@ class EventsStore(
"local_invites",
"room_account_data",
"room_tags",
+ "local_current_membership",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 0cce5232f5..3b93e0597a 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -37,8 +37,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
from synapse.util.caches.descriptors import Cache
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index 6b12f5a75f..ba89c68c9f 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -23,8 +23,8 @@ from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
-from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index a2c83e0867..604c8b7ddd 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
-from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 70ff5751b6..9acef7c950 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0]: row[1] for row in txn}
@cached()
- def get_invited_rooms_for_user(self, user_id):
- """ Get all the rooms the user is invited to
+ def get_invited_rooms_for_local_user(self, user_id):
+ """ Get all the rooms the *local* user is invited to
+
Args:
user_id (str): The user ID.
Returns:
A deferred list of RoomsForUser.
"""
- return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
+ return self.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE]
+ )
@defer.inlineCallbacks
- def get_invite_for_user_in_room(self, user_id, room_id):
- """Gets the invite for the given user and room
+ def get_invite_for_local_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given *local* user and room
Args:
user_id (str)
@@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
- invites = yield self.get_invited_rooms_for_user(user_id)
+ invites = yield self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
@defer.inlineCallbacks
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this user where the membership for this user
+ def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
@@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return defer.succeed(None)
rooms = yield self.db.runInteraction(
- "get_rooms_for_user_where_membership_is",
- self._get_rooms_for_user_where_membership_is_txn,
+ "get_rooms_for_local_user_where_membership_is",
+ self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
membership_list,
)
@@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
- def _get_rooms_for_user_where_membership_is_txn(
+ def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+ % (user_id,),
+ )
- do_invite = Membership.INVITE in membership_list
- membership_list = [m for m in membership_list if m != Membership.INVITE]
-
- results = []
- if membership_list:
- if self._current_state_events_membership_up_to_date:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "c.membership", membership_list
- )
- sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND state_key = ?
- AND %s
- """ % (
- clause,
- )
- else:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "m.membership", membership_list
- )
- sql = """
- SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (room_id, event_id)
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND state_key = ?
- AND %s
- """ % (
- clause,
- )
-
- txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "c.membership", membership_list
+ )
- if do_invite:
- sql = (
- "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
- " FROM local_invites as i"
- " INNER JOIN events as e USING (event_id)"
- " WHERE invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
+ sql = """
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ FROM local_current_membership AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND %s
+ """ % (
+ clause,
+ )
- txn.execute(sql, (user_id,))
- results.extend(
- RoomsForUser(
- room_id=r["room_id"],
- sender=r["inviter"],
- event_id=r["event_id"],
- stream_ordering=r["stream_ordering"],
- membership=Membership.INVITE,
- )
- for r in self.db.cursor_to_dict(txn)
- )
+ txn.execute(sql, (user_id, *args))
+ results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
return results
- @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ @cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id):
- """Returns a set of room_ids the user is currently joined to
+ """Returns a set of room_ids the user is currently joined to.
+
+ If a remote user only returns rooms this server is currently
+ participating in.
Args:
user_id (str)
@@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
"""
- rooms = yield self.get_rooms_for_user_where_membership_is(
- user_id, membership_list=[Membership.JOIN]
- )
- return frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
+ return self.db.runInteraction(
+ "get_rooms_for_user_with_stream_ordering",
+ self._get_rooms_for_user_with_stream_ordering_txn,
+ user_id,
)
+ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ # We use `current_state_events` here and not `local_current_membership`
+ # as a) this gets called with remote users and b) this only gets called
+ # for rooms the server is participating in.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND c.membership = ?
+ """
+ else:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (room_id, event_id)
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND m.membership = ?
+ """
+
+ txn.execute(sql, (user_id, Membership.JOIN))
+ results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+ return results
+
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
- """Returns a set of room_ids the user is currently joined to
+ """Returns a set of room_ids the user is currently joined to.
+
+ If a remote user only returns rooms this server is currently
+ participating in.
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
@@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
event.internal_metadata.stream_ordering,
)
txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
@@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
),
)
+ # We also update the `local_current_membership` table with
+ # latest invite info. This will usually get updated by the
+ # `current_state_events` handling, unless its an outlier.
+ if event.internal_metadata.is_outlier():
+ # This should only happen for out of band memberships, so
+ # we add a paranoia check.
+ assert event.internal_metadata.is_out_of_band_membership()
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={
+ "room_id": event.room_id,
+ "user_id": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
+
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
@@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
+ # We also clear this entry from `local_current_membership`.
+ # Ideally we'd point to a leave event, but we don't have one, so
+ # nevermind.
+ self.db.simple_delete_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ )
+
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..601c236c4a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outsstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+ # We need to do the insert in `run_upgrade` section as we don't have access
+ # to `config` in `run_create`.
+
+ # This upgrade may take a bit of time for large servers (e.g. one minute for
+ # matrix.org) but means we avoid a lots of book keeping required to do it as
+ # a background update.
+
+ # We check if the `current_state_events.membership` is up to date by
+ # checking if the relevant background update has finished. If it has
+ # finished we can avoid doing a join against `room_memberships`, which
+ # speesd things up.
+ cur.execute(
+ """SELECT 1 FROM background_updates
+ WHERE update_name = 'current_state_events_membership'
+ """
+ )
+ current_state_membership_up_to_date = not bool(cur.fetchone())
+
+ # Cheekily drop and recreate indices, as that is faster.
+ cur.execute("DROP INDEX local_current_membership_idx")
+ cur.execute("DROP INDEX local_current_membership_room_idx")
+
+ if current_state_membership_up_to_date:
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, c.membership
+ FROM current_state_events AS c
+ WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ?
+ """
+ else:
+ # We can't rely on the membership column, so we need to join against
+ # `room_memberships`.
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, r.membership
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS r USING (event_id)
+ WHERE type = 'm.room.member' and state_key like '%' || ?
+ """
+ cur.execute(sql, (config.server_name,))
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ """
+ CREATE TABLE local_current_membership (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ membership TEXT NOT NULL
+ )"""
+ )
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d07440e3ed..33bebd1c48 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+ def get_filtered_current_state_ids(
+ self, room_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
- room_id (str)
- state_filter (StateFilter): The state filter used to fetch state
+ room_id
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
- event ID.
+ defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index d53695f238..c4ee9b7ccb 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
from six import iteritems
from six.moves import range
@@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
+ """Returns the state groups for a given set of groups from the
+ database, filtering on types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
results = {}
@@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
- groups (iterable[int]): list of state groups for which we want
+ groups: list of state groups for which we want
to get the state.
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+ def _get_state_for_groups_using_cache(
+ self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+ ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ groups: list of state groups for which we want to get the state.
+ cache: the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the
+ specific members state cache.
+ state_filter: The state filter used to fetch state from the
+ database.
Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
+ Tuple of dict of state_group_id to state map of entries in the
+ cache, and the state group ids either missing from the cache or
+ incomplete.
"""
results = {}
incomplete_groups = set()
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e70026b80a..e86984cd50 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+SCHEMA_VERSION = 57
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index d6a7bd7834..fdc0abf5cf 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -34,7 +34,7 @@ class PurgeEventsStorage(object):
"""
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
- yield self.stores.main.purge_room_state(room_id, state_groups_to_delete)
+ yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@defer.inlineCallbacks
def purge_history(self, room_id, token, delete_local_events):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
@@ -22,9 +23,13 @@ import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
+# Used for generic functions below
+T = TypeVar("T")
+
@attr.s(slots=True)
class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types())
- def filter_state(self, state_dict):
+ def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
Args:
- state (dict[tuple[str, str], Any]): The state map to filter
+ state: The state map to filter
Returns:
- dict[tuple[str, str], Any]: The filtered state map
+ The filtered state map
"""
if self.is_full():
return dict(state_dict)
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores
- def get_state_group_delta(self, state_group):
+ def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+ Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
"""
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
event_ids (iterable[str]): ids of the events
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
+ Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
for group, event_id_map in iteritems(group_to_ids)
}
- def _get_state_groups_from_groups(self, groups, state_filter):
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
diff --git a/synapse/types.py b/synapse/types.py
index cd996c0b5a..65e4d8c181 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -17,6 +17,7 @@ import re
import string
import sys
from collections import namedtuple
+from typing import Dict, Tuple, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
- from typing import Sized, Iterable, Container, TypeVar
+ from typing import Sized, Iterable, Container
T_co = TypeVar("T_co", covariant=True)
@@ -36,6 +37,12 @@ else:
__slots__ = ()
+# Define a state map type from type/state_key to T (usually an event ID or
+# event)
+T = TypeVar("T")
+StateMap = Dict[Tuple[str, str], T]
+
+
class Requester(
namedtuple(
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 7856353002..60f0de70f7 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,7 +15,6 @@
import logging
import re
-from itertools import islice
import attr
@@ -107,22 +106,6 @@ class Clock(object):
raise
-def batch_iter(iterable, size):
- """batch an iterable up into tuples with a maximum size
-
- Args:
- iterable (iterable): the iterable to slice
- size (int): the maximum batch size
-
- Returns:
- an iterator over the chunks
- """
- # make sure we can deal with iterables like lists too
- sourceiter = iter(iterable)
- # call islice until it returns an empty tuple
- return iter(lambda: tuple(islice(sourceiter, size)), ())
-
-
def log_failure(failure, msg, consumeErrors=True):
"""Creates a function suitable for passing to `Deferred.addErrback` that
logs any failures that occur.
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
new file mode 100644
index 0000000000..06faeebe7f
--- /dev/null
+++ b/synapse/util/iterutils.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from itertools import islice
+from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+
+T = TypeVar("T")
+
+
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
+ """batch an iterable up into tuples with a maximum size
+
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
+
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
+
+
+ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
+
+
+def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+ """Split the given sequence into chunks of the given size
+
+ The last chunk may be shorter than the given size.
+
+ If the input is empty, no chunks are returned.
+ """
+ return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 2705cbe5f8..bb62db4637 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -34,7 +34,7 @@ def load_module(provider):
provider_class = getattr(module, clz)
try:
- provider_config = provider_class.parse_config(provider["config"])
+ provider_config = provider_class.parse_config(provider.get("config"))
except Exception as e:
raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 758ee071a5..4cbe9784ed 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -32,8 +32,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
- user_id1 = "@user1:server"
- user_id2 = "@user2:server"
+ user_id1 = "@user1:test"
+ user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b68e9fe082..b1b037006d 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -115,13 +115,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def test_invites(self):
self.persist(type="m.room.create", key="", creator=USER_ID)
- self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+ self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
- "get_invited_rooms_for_user",
+ "get_invited_rooms_for_local_user",
[USER_ID_2],
[
RoomsForUser(
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 1d14e77255..e96ad4ca4e 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -73,6 +73,6 @@ class TestReplicationClientHandler(object):
def finished_connecting(self):
pass
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 7a7e898843..f3b4a31e21 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -337,7 +337,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"local_invites",
"room_account_data",
"room_tags",
- "state_groups",
+ # "state_groups", # Current impl leaves orphaned state groups around.
"state_groups_state",
):
count = self.get_success(
@@ -351,8 +351,6 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
- test_purge_room.skip = "Disabled because it's currently broken"
-
class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Test /quarantine_media admin API.
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 0f51895b81..c3facc00eb 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -285,7 +285,9 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
)
# Make sure the invite is here.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
@@ -293,12 +295,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
- pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
+ pending_invites = self.get_success(
+ store.get_invited_rooms_for_local_user(invitee_id)
+ )
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
- store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
+ store.get_rooms_for_local_user_where_membership_is(
+ invitee_id, [Membership.LEAVE]
+ )
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 661c1f88b9..9c13a13786 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -15,8 +15,6 @@
# limitations under the License.
import json
-from mock import Mock
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
@@ -36,13 +34,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def make_homeserver(self, reactor, clock):
-
- hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
- )
- return hs
-
def test_sync_argless(self):
request, channel = self.make_request("GET", "/sync")
self.render(request)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 7840f63fe3..00df0ea68e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -57,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
rooms_for_user = self.get_success(
- self.store.get_rooms_for_user_where_membership_is(
+ self.store.get_rooms_for_local_user_where_membership_is(
self.u_alice, [Membership.JOIN]
)
)
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ JsonResource,
+ wrap_html_request_handler,
+)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+ class TestResource(DirectServeResource):
+ callback = None
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self.callback(request)
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def test_good_response(self):
+ def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ body = channel.result["body"]
+ self.assertEqual(body, b"response")
+
+ def test_redirect_exception(self):
+ """
+ If the callback raises a RedirectException, it is turned into a 30x
+ with the right location.
+ """
+
+ def callback(request, **kwargs):
+ raise RedirectException(b"/look/an/eagle", 301)
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"301")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+ def test_redirect_exception_with_cookie(self):
+ """
+ If the callback raises a RedirectException which sets a cookie, that is
+ returned too
+ """
+
+ def callback(request, **kwargs):
+ e = RedirectException(b"/no/over/there", 304)
+ e.cookies.append(b"session=yespls")
+ raise e
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"304")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/no/over/there"])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
new file mode 100644
index 0000000000..0ab0a91483
--- /dev/null
+++ b/tests/util/test_itertools.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.util.iterutils import chunk_seq
+
+from tests.unittest import TestCase
+
+
+class ChunkSeqTests(TestCase):
+ def test_short_seq(self):
+ parts = chunk_seq("123", 8)
+
+ self.assertEqual(
+ list(parts), ["123"],
+ )
+
+ def test_long_seq(self):
+ parts = chunk_seq("abcdefghijklmnop", 8)
+
+ self.assertEqual(
+ list(parts), ["abcdefgh", "ijklmnop"],
+ )
+
+ def test_uneven_parts(self):
+ parts = chunk_seq("abcdefghijklmnop", 5)
+
+ self.assertEqual(
+ list(parts), ["abcde", "fghij", "klmno", "p"],
+ )
+
+ def test_empty_input(self):
+ parts = chunk_seq([], 5)
+
+ self.assertEqual(
+ list(parts), [],
+ )
|