Diffstat (limited to 'synapse')
64 files changed, 922 insertions, 481 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 0dd538d804..abd5297390 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.8.0" +__version__ = "1.9.0.dev1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py index abbc7079a3..2cbfab2569 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,7 +14,6 @@ # limitations under the License. import logging -from typing import Dict, Tuple from six import itervalues @@ -35,7 +34,7 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.config.server import is_threepid_reserved -from synapse.types import UserID +from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -509,10 +508,7 @@ class Auth(object): return self.store.is_server_admin(user) def compute_auth_events( - self, - event, - current_state_ids: Dict[Tuple[str, str], str], - for_verification: bool = False, + self, event, current_state_ids: StateMap[str], for_verification: bool = False, ): """Given an event and current state return the list of event IDs used to auth an event. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 9e9844b47c..1c9456e583 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,13 +17,15 @@ """Contains exceptions and error codes.""" import logging -from typing import Dict +from typing import Dict, List from six import iteritems from six.moves import http_client from canonicaljson import json +from twisted.web import http + logger = logging.getLogger(__name__) @@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError): self.msg = msg +class RedirectException(CodeMessageException): + """A pseudo-error indicating that we want to redirect the client to a different + location + + Attributes: + cookies: a list of set-cookies values to add to the response. For example: + b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT" + """ + + def __init__(self, location: bytes, http_code: int = http.FOUND): + """ + + Args: + location: the URI to redirect to + http_code: the HTTP response code + """ + msg = "Redirect to %s" % (location.decode("utf-8"),) + super().__init__(code=http_code, msg=msg) + self.location = location + + self.cookies = [] # type: List[bytes] + + class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error message (as well as an HTTP status code). 
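The new RedirectException added to synapse/api/errors.py above lets request-handling code abort processing and send the client elsewhere: it carries the Location value as bytes, defaults to HTTP 302 (twisted.web.http.FOUND), and exposes a cookies list whose entries are copied onto the response by _return_html_error (see the synapse/http/server.py hunk later in this diff). A minimal sketch of how HTML-serving code might use it; the resource class, URL and cookie value are illustrative, not part of the diff:

    from synapse.api.errors import RedirectException


    class ExampleSsoResource:
        """Hypothetical HTML resource that bounces unauthenticated users to an SSO page."""

        def _is_authenticated(self, request) -> bool:
            # Stub for illustration only.
            return False

        def handle_request(self, request):
            if not self._is_authenticated(request):
                # Location is passed as bytes; the response code defaults to 302.
                e = RedirectException(b"https://sso.example.com/login")
                # Optional Set-Cookie values; _return_html_error extends
                # request.cookies with these before returning the redirect.
                e.cookies.append(
                    b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
                )
                raise e
            # ... normal handling continues here ...
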
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 8e36bc57d3..1c7c6ec0c8 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer): class AdminCmdReplicationHandler(ReplicationClientHandler): - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): pass def get_streams_to_replicate(self): diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index e82e0f11e3..2217d4a4fb 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler): super(ASReplicationHandler, self).__init__(hs.get_datastore()) self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows) if stream_name == "events": max_stream_id = self.store.get_room_max_stream_ordering() diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 83c436229c..38d11fdd0f 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore()) self.send_handler = FederationSenderHandler(hs, self) - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(FederationSenderReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(FederationSenderReplicationHandler, self).on_rdata( stream_name, token, rows ) self.send_handler.process_replication_rows(stream_name, token, rows) @@ -159,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + # Let's wake up the transaction queue for the server in case we have + # pending stuff to send to it. + self.send_handler.wake_destination(server) + def start(config_options): try: @@ -206,7 +212,7 @@ class FederationSenderHandler(object): to the federation sender. 
""" - def __init__(self, hs, replication_client): + def __init__(self, hs: FederationSenderServer, replication_client): self.store = hs.get_datastore() self._is_mine_id = hs.is_mine_id self.federation_sender = hs.get_federation_sender() @@ -227,6 +233,9 @@ class FederationSenderHandler(object): self.store.get_room_max_stream_ordering() ) + def wake_destination(self, server: str): + self.federation_sender.wake_destination(server) + def stream_positions(self): return {"federation": self.federation_position} diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 09e639040a..e46b6ac598 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler): self.pusher_pool = hs.get_pusherpool() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.poke_pushers, stream_name, token, rows) @defer.inlineCallbacks diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 03031ee34d..3218da07bd 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler): self.presence_handler = hs.get_presence_handler() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, token, rows): + await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) run_in_background(self.process_and_notify, stream_name, token, rows) def get_streams_to_replicate(self): diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 1257098f92..ba536d6f04 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore()) self.user_directory = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_rdata(self, stream_name, token, rows): - yield super(UserDirectoryReplicationHandler, self).on_rdata( + async def on_rdata(self, stream_name, token, rows): + await super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) if stream_name == EventsStream.NAME: diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 35756bed87..74853f9faa 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -37,10 +37,12 @@ class EmailConfig(Config): self.email_enable_notifs = False - email_config = config.get("email", {}) + email_config = config.get("email") + if email_config is None: + email_config = {} - self.email_smtp_host = email_config.get("smtp_host", None) - self.email_smtp_port = email_config.get("smtp_port", None) + self.email_smtp_host = email_config.get("smtp_host", "localhost") + self.email_smtp_port = email_config.get("smtp_port", 25) self.email_smtp_user = email_config.get("smtp_user", None) self.email_smtp_pass = email_config.get("smtp_pass", None) self.require_transport_security = email_config.get( @@ -74,9 +76,9 @@ class EmailConfig(Config): self.email_template_dir = os.path.abspath(template_dir) self.email_enable_notifs = email_config.get("enable_notifs", False) - 
account_validity_renewal_enabled = config.get("account_validity", {}).get( - "renew_at" - ) + + account_validity_config = config.get("account_validity") or {} + account_validity_renewal_enabled = account_validity_config.get("renew_at") self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email @@ -278,7 +280,9 @@ class EmailConfig(Config): self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True ) - self.email_riot_base_url = email_config.get("riot_base_url", None) + self.email_riot_base_url = email_config.get( + "client_base_url", email_config.get("riot_base_url", None) + ) if account_validity_renewal_enabled: self.email_expiry_template_html = email_config.get( @@ -294,107 +298,111 @@ class EmailConfig(Config): raise ConfigError("Unable to find email template file %s" % (p,)) def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """ - # Enable sending emails for password resets, notification events or - # account expiry notices - # - # If your SMTP server requires authentication, the optional smtp_user & - # smtp_pass variables should be used - # - #email: - # enable_notifs: false - # smtp_host: "localhost" - # smtp_port: 25 # SSL: 465, STARTTLS: 587 - # smtp_user: "exampleusername" - # smtp_pass: "examplepassword" - # require_transport_security: false - # - # # notif_from defines the "From" address to use when sending emails. - # # It must be set if email sending is enabled. - # # - # # The placeholder '%(app)s' will be replaced by the application name, - # # which is normally 'app_name' (below), but may be overridden by the - # # Matrix client application. - # # - # # Note that the placeholder must be written '%(app)s', including the - # # trailing 's'. - # # - # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" - # - # # app_name defines the default value for '%(app)s' in notif_from. It - # # defaults to 'Matrix'. - # # - # #app_name: my_branded_matrix_server - # - # # Enable email notifications by default - # # - # notif_for_new_users: true - # - # # Defining a custom URL for Riot is only needed if email notifications - # # should contain links to a self-hosted installation of Riot; when set - # # the "app_name" setting is ignored - # # - # riot_base_url: "http://localhost/riot" - # - # # Configure the time that a validation email or text message code - # # will expire after sending - # # - # # This is currently used for password resets - # # - # #validation_token_lifetime: 1h - # - # # Template directory. All template files should be stored within this - # # directory. 
If not set, default templates from within the Synapse - # # package will be used - # # - # # For the list of default templates, please see - # # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # # - # #template_dir: res/templates - # - # # Templates for email notifications - # # - # notif_template_html: notif_mail.html - # notif_template_text: notif_mail.txt - # - # # Templates for account expiry notices - # # - # expiry_template_html: notice_expiry.html - # expiry_template_text: notice_expiry.txt - # - # # Templates for password reset emails sent by the homeserver - # # - # #password_reset_template_html: password_reset.html - # #password_reset_template_text: password_reset.txt - # - # # Templates for registration emails sent by the homeserver - # # - # #registration_template_html: registration.html - # #registration_template_text: registration.txt - # - # # Templates for validation emails sent by the homeserver when adding an email to - # # your user account - # # - # #add_threepid_template_html: add_threepid.html - # #add_threepid_template_text: add_threepid.txt - # - # # Templates for password reset success and failure pages that a user - # # will see after attempting to reset their password - # # - # #password_reset_template_success_html: password_reset_success.html - # #password_reset_template_failure_html: password_reset_failure.html - # - # # Templates for registration success and failure pages that a user - # # will see after attempting to register using an email or phone - # # - # #registration_template_success_html: registration_success.html - # #registration_template_failure_html: registration_failure.html + return """\ + # Configuration for sending emails from Synapse. # - # # Templates for success and failure pages that a user will see after attempting - # # to add an email or phone to their account - # # - # #add_threepid_success_html: add_threepid_success.html - # #add_threepid_failure_html: add_threepid_failure.html + email: + # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. + # + #smtp_host: mail.server + + # The port on the mail server for outgoing SMTP. Defaults to 25. + # + #smtp_port: 587 + + # Username/password for authentication to the SMTP server. By default, no + # authentication is attempted. + # + # smtp_user: "exampleusername" + # smtp_pass: "examplepassword" + + # Uncomment the following to require TLS transport security for SMTP. + # By default, Synapse will connect over plain text, and will then switch to + # TLS via STARTTLS *if the SMTP server supports it*. If this option is set, + # Synapse will refuse to connect unless the server supports STARTTLS. + # + #require_transport_security: true + + # Enable sending emails for messages that the user has missed + # + #enable_notifs: false + + # notif_from defines the "From" address to use when sending emails. + # It must be set if email sending is enabled. + # + # The placeholder '%(app)s' will be replaced by the application name, + # which is normally 'app_name' (below), but may be overridden by the + # Matrix client application. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" + + # app_name defines the default value for '%(app)s' in notif_from. It + # defaults to 'Matrix'. + # + #app_name: my_branded_matrix_server + + # Uncomment the following to disable automatic subscription to email + # notifications for new users. Enabled by default. 
+ # + #notif_for_new_users: false + + # Custom URL for client links within the email notifications. By default + # links will be based on "https://matrix.to". + # + # (This setting used to be called riot_base_url; the old name is still + # supported for backwards-compatibility but is now deprecated.) + # + #client_base_url: "http://localhost/riot" + + # Configure the time that a validation email will expire after sending. + # Defaults to 1h. + # + #validation_token_lifetime: 15m + + # Directory in which Synapse will try to find the template files below. + # If not set, default templates from within the Synapse package will be used. + # + # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. + # If you *do* uncomment it, you will need to make sure that all the templates + # below are in the directory. + # + # Synapse will look for the following templates in this directory: + # + # * The contents of email notifications of missed events: 'notif_mail.html' and + # 'notif_mail.txt'. + # + # * The contents of account expiry notice emails: 'notice_expiry.html' and + # 'notice_expiry.txt'. + # + # * The contents of password reset emails sent by the homeserver: + # 'password_reset.html' and 'password_reset.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in the password reset email: 'password_reset_success.html' and + # 'password_reset_failure.html' + # + # * The contents of address verification emails sent during registration: + # 'registration.html' and 'registration.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in an address verification email sent during registration: + # 'registration_success.html' and 'registration_failure.html' + # + # * The contents of address verification emails sent when an address is added + # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' + # + # * HTML pages for success and failure that a user will see when they follow + # the link in an address verification email sent when an address is added + # to a Matrix account: 'add_threepid_success.html' and + # 'add_threepid_failure.html' + # + # You can see the default templates at: + # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates + # + #template_dir: "res/templates" """ diff --git a/synapse/config/push.py b/synapse/config/push.py index 0910958649..6f2b3a7faa 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -35,7 +35,7 @@ class PushConfig(Config): # Now check for the one in the 'email' section and honour it, # with a warning. - push_config = config.get("email", {}) + push_config = config.get("email") or {} redact_content = push_config.get("redact_content") if redact_content is not None: print( diff --git a/synapse/config/registration.py b/synapse/config/registration.py index ee9614c5f7..b873995a49 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -27,6 +27,8 @@ class AccountValidityConfig(Config): section = "accountvalidity" def __init__(self, config, synapse_config): + if config is None: + return self.enabled = config.get("enabled", False) self.renew_by_email_enabled = "renew_at" in config @@ -159,23 +161,6 @@ class RegistrationConfig(Config): # Optional account validity configuration. This allows for accounts to be denied # any request after a given period. # - # ``enabled`` defines whether the account validity feature is enabled. Defaults - # to False. 
- # - # ``period`` allows setting the period after which an account is valid - # after its registration. When renewing the account, its validity period - # will be extended by this amount of time. This parameter is required when using - # the account validity feature. - # - # ``renew_at`` is the amount of time before an account's expiry date at which - # Synapse will send an email to the account's email address with a renewal link. - # This needs the ``email`` and ``public_baseurl`` configuration sections to be - # filled. - # - # ``renew_email_subject`` is the subject of the email sent out with the renewal - # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter - # from the ``email`` section. - # # Once this feature is enabled, Synapse will look for registered users without an # expiration date at startup and will add one to every account it found using the # current settings at that time. @@ -186,21 +171,55 @@ class RegistrationConfig(Config): # date will be randomly selected within a range [now + period - d ; now + period], # where d is equal to 10%% of the validity period. # - #account_validity: - # enabled: true - # period: 6w - # renew_at: 1w - # renew_email_subject: "Renew your %%(app)s account" - # # Directory in which Synapse will try to find the HTML files to serve to the - # # user when trying to renew an account. Optional, defaults to - # # synapse/res/templates. - # template_dir: "res/templates" - # # HTML to be displayed to the user after they successfully renewed their - # # account. Optional. - # account_renewed_html_path: "account_renewed.html" - # # HTML to be displayed when the user tries to renew an account with an invalid - # # renewal token. Optional. - # invalid_token_html_path: "invalid_token.html" + account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %%(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + #template_dir: "res/templates" + + # File within 'template_dir' giving the HTML to be displayed to the user after + # they successfully renewed their account. If not set, default text is used. + # + #account_renewed_html_path: "account_renewed.html" + + # File within 'template_dir' giving the HTML to be displayed when the user + # tries to renew an account with an invalid renewal token. If not set, + # default text is used. 
+ # + #invalid_token_html_path: "invalid_token.html" # Time that a user's session remains valid for, after they log in. # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index b91414aa35..423c158b11 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -121,6 +121,7 @@ class SAML2Config(Config): required_methods = [ "get_saml_attributes", "saml_response_to_user_attributes", + "get_remote_user_id", ] missing_methods = [ method diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a44baea365..9ea85e93e6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from six import iteritems @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import StateMap @attr.s(slots=True) @@ -106,13 +107,11 @@ class EventContext: _state_group = attr.ib(default=None, type=Optional[int]) state_group_before_event = attr.ib(default=None, type=Optional[int]) prev_group = attr.ib(default=None, type=Optional[int]) - delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + delta_ids = attr.ib(default=None, type=Optional[StateMap[str]]) app_service = attr.ib(default=None, type=Optional[ApplicationService]) - _current_state_ids = attr.ib( - default=None, type=Optional[Dict[Tuple[str, str], str]] - ) - _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) + _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) @staticmethod def with_state( diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index ced4925a98..174f6e42be 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object): def federation_ack(self, token): self._clear_queue_before_pos(token) - def get_replication_rows(self, from_token, to_token, limit, federation_ack=None): + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): """Get rows to be sent over federation between the two tokens Args: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4ebb0e8bc0..36c83c3027 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -21,6 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer +import synapse import synapse.metrics from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter( class FederationSender(object): - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.server_name = hs.hostname @@ -482,7 +483,20 @@ class FederationSender(object): def send_device_messages(self, destination): if destination == self.server_name: - logger.info("Not sending device update to ourselves") + logger.warning("Not sending device update to ourselves") + return + 
+ self._get_per_destination_queue(destination).attempt_new_transaction() + + def wake_destination(self, destination: str): + """Called when we want to retry sending transactions to a remote. + + This is mainly useful if the remote server has been down and we think it + might have come back. + """ + + if destination == self.server_name: + logger.warning("Not waking up ourselves") return self._get_per_destination_queue(destination).attempt_new_transaction() diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index a5b36b1827..5012aaea35 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState +from synapse.types import StateMap from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -77,7 +78,7 @@ class PerDestinationQueue(object): # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + self._pending_edus_keyed = {} # type: StateMap[Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index b4cbf23394..d8cf9ed299 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -44,6 +44,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string @@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError): class Authenticator(object): - def __init__(self, hs): + def __init__(self, hs: HomeServer): self._clock = hs.get_clock() self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifer = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request(self, request, content): @@ -166,6 +172,17 @@ class Authenticator(object): try: logger.info("Marking origin %r as up", origin) await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifer.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. 
+ self.replication_client.send_remote_server_up(origin) + except Exception: logger.exception("Error resetting retry timings on %s", origin) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 76d18a8ba8..60a7c938bc 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,9 +14,11 @@ # limitations under the License. import logging +from typing import List from synapse.api.constants import Membership -from synapse.types import RoomStreamToken +from synapse.events import FrozenEvent +from synapse.types import RoomStreamToken, StateMap from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -134,7 +136,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = await self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_local_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -259,35 +261,26 @@ class ExfiltrationWriter(object): """Interface used to specify how to write exported data. """ - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[FrozenEvent]): """Write a batch of events for a room. - - Args: - room_id (str) - events (list[FrozenEvent]) """ pass - def write_state(self, room_id, event_id, state): + def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): """Write the state at the given event in the room. This only gets called for backward extremities rather than for each event. - - Args: - room_id (str) - event_id (str) - state (dict[tuple[str, str], FrozenEvent]) """ pass - def write_invite(self, room_id, event, state): + def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): """Write an invite for the room, with associated invite state. Args: - room_id (str) - event (FrozenEvent) - state (dict[tuple[str, str], dict]): A subset of the state at the + room_id + event + state: A subset of the state at the invite, with a subset of the event keys (type, state_key content and sender) """ diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 4426967f88..2afb390a92 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler): user_id (str): The user ID to reject pending invites for. 
""" user = UserID.from_string(user_id) - pending_invites = await self.store.get_invited_rooms_for_user(user_id) + pending_invites = await self.store.get_invited_rooms_for_local_user(user_id) for room in pending_invites: try: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 61b6713c88..d4f9a792fc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,7 +64,7 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.types import UserID, get_domain_from_id +from synapse.types import StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -89,7 +89,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) + auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) def shortstr(iterable, maxitems=5): @@ -352,9 +352,7 @@ class FederationHandler(BaseHandler): ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list( - ours.values() - ) # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) # type: list[StateMap[str]] # we don't need this any more, let's delete it. del ours @@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[Dict[Tuple[str, str], EventBase]], + auth_events: Optional[StateMap[EventBase]], backfilled: bool, ): """ diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 44ec3e66ae..2e6755f19c 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = await self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=memberships ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9cab2adbfb..9f50196ea7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Requester, + RoomAlias, + RoomID, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache @@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, + requester: Requester, + old_room_id: str, + new_room_id: 
str, + old_room_state: StateMap[str], ): """Send updated power levels in both rooms after an upgrade Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (str): the id of the room to be replaced - new_room_id (str): the id of the replacement room - old_room_state (dict[tuple[str, str], str]): the state map for the old room + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_room_id: the id of the replacement room + old_room_state: the state map for the old room Returns: Deferred diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 03bb52ccfb..15e8aa5249 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -690,7 +690,7 @@ class RoomMemberHandler(object): @defer.inlineCallbacks def _get_inviter(self, user_id, room_id): - invite = yield self.store.get_invite_for_user_in_room( + invite = yield self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 107f97032b..7f411b53b9 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -32,6 +32,7 @@ from synapse.types import ( mxid_localpart_allowed_characters, ) from synapse.util.async_helpers import Linearizer +from synapse.util.iterutils import chunk_seq logger = logging.getLogger(__name__) @@ -132,17 +133,28 @@ class SamlHandler: logger.warning("SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed") - logger.info("SAML2 response: %s", saml2_auth.origxml) - logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) + logger.debug("SAML2 response: %s", saml2_auth.origxml) + for assertion in saml2_auth.assertions: + # kibana limits the length of a log field, whereas this is all rather + # useful, so split it up. + count = 0 + for part in chunk_seq(str(assertion), 10000): + logger.info( + "SAML2 assertion: %s%s", "(%i)..." 
% (count,) if count else "", part + ) + count += 1 - try: - remote_user_id = saml2_auth.ava["uid"][0] - except KeyError: - logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "'uid' not in SAML2 response") + logger.info("SAML2 mapped attributes: %s", saml2_auth.ava) self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) + remote_user_id = self._user_mapping_provider.get_remote_user_id( + saml2_auth, client_redirect_url + ) + + if not remote_user_id: + raise Exception("Failed to extract remote user id from SAML response") + with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -279,6 +291,20 @@ class DefaultSamlMappingProvider(object): self._mxid_source_attribute = parsed_config.mxid_source_attribute self._mxid_mapper = parsed_config.mxid_mapper + self._grandfathered_mxid_source_attribute = ( + module_api._hs.config.saml2_grandfathered_mxid_source_attribute + ) + + def get_remote_user_id( + self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str + ): + """Extracts the remote user id from the SAML response""" + try: + return saml_response.ava["uid"][0] + except KeyError: + logger.warning("SAML2 response lacks a 'uid' attestation") + raise SynapseError(400, "'uid' not in SAML2 response") + def saml_response_to_user_attributes( self, saml_response: saml2.response.AuthnResponse, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ef750d1497..110097eab9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -179,7 +179,7 @@ class SearchHandler(BaseHandler): search_filter = Filter(filter_dict) # TODO: Search through left rooms too - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = yield self.store.get_rooms_for_local_user_where_membership_is( user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2d3b8ba73c..cd95f85e3f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1662,7 +1662,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = await self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=membership_list ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b635c339ed..d5ca9cb07b 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -257,7 +257,7 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates(self, last_id, current_id): if last_id == current_id: return [] diff --git a/synapse/http/server.py b/synapse/http/server.py index 943d12c907..04bc2385a2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import cgi import collections +import html import http.client import logging import types @@ -36,6 +36,7 @@ import synapse.metrics from synapse.api.errors import ( CodeMessageException, Codes, + RedirectException, SynapseError, UnrecognizedRequestError, ) @@ -153,14 +154,18 @@ def _return_html_error(f, request): Args: f (twisted.python.failure.Failure): - request (twisted.web.iweb.IRequest): + request (twisted.web.server.Request): """ if f.check(CodeMessageException): cme = f.value code = cme.code msg = cme.msg - if isinstance(cme, SynapseError): + if isinstance(cme, RedirectException): + logger.info("%s redirect to %s", request, cme.location) + request.setHeader(b"location", cme.location) + request.cookies.extend(cme.cookies) + elif isinstance(cme, SynapseError): logger.info("%s SynapseError: %s - %s", request, code, msg) else: logger.error( @@ -178,7 +183,7 @@ def _return_html_error(f, request): exc_info=(f.type, f.value, f.getTracebackObject()), ) - body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") + body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8") request.setResponseCode(code) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Length", b"%i" % (len(body),)) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 305b9b0178..d680ee95e1 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,18 +17,26 @@ import logging from twisted.internet import defer +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import UserID +""" +This package defines the 'stable' API which can be used by extension modules which +are loaded into Synapse. +""" + +__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"] + logger = logging.getLogger(__name__) class ModuleApi(object): - """A proxy object that gets passed to password auth providers so they + """A proxy object that gets passed to various plugin modules so they can register new users etc if necessary. """ def __init__(self, hs, auth_handler): - self.hs = hs + self._hs = hs self._store = hs.get_datastore() self._auth = hs.get_auth() @@ -64,7 +73,7 @@ class ModuleApi(object): """ if username.startswith("@"): return username - return UserID(username, self.hs.hostname).to_string() + return UserID(username, self._hs.hostname).to_string() def check_user_exists(self, user_id): """Check if user exists. @@ -111,10 +120,14 @@ class ModuleApi(object): displayname (str|None): The displayname of the new user. emails (List[str]): Emails to bind to the new user. + Raises: + SynapseError if there is an error performing the registration. 
Check the + 'errcode' property for more information on the reason for failure + Returns: Deferred[str]: user_id """ - return self.hs.get_registration_handler().register_user( + return self._hs.get_registration_handler().register_user( localpart=localpart, default_display_name=displayname, bind_emails=emails ) @@ -131,12 +144,34 @@ class ModuleApi(object): Returns: defer.Deferred[tuple[str, str]]: Tuple of device ID and access token """ - return self.hs.get_registration_handler().register_device( + return self._hs.get_registration_handler().register_device( user_id=user_id, device_id=device_id, initial_display_name=initial_display_name, ) + def record_user_external_id( + self, auth_provider_id: str, remote_user_id: str, registered_user_id: str + ) -> defer.Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self._store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) + + def generate_short_term_login_token( + self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + ) -> str: + """Generate a login token suitable for m.login.token authentication""" + return self._hs.get_macaroon_generator().generate_short_term_login_token( + user_id, duration_in_ms + ) + @defer.inlineCallbacks def invalidate_access_token(self, access_token): """Invalidate an access token for a user @@ -157,7 +192,7 @@ class ModuleApi(object): user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self.hs.get_device_handler().delete_device(user_id, device_id) + yield self._hs.get_device_handler().delete_device(user_id, device_id) else: # no associated device. Just delete the access token. yield self._auth_handler.delete_access_token(access_token) diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py new file mode 100644 index 0000000000..b15441772c --- /dev/null +++ b/synapse/module_api/errors.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Exception types which are exposed as part of the stable module API""" + +from synapse.api.errors import RedirectException, SynapseError # noqa: F401 diff --git a/synapse/notifier.py b/synapse/notifier.py index 5f5f765bea..6132727cbd 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,11 +15,13 @@ import logging from collections import namedtuple +from typing import Callable, List from prometheus_client import Counter from twisted.internet import defer +import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state @@ -154,7 +156,7 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.user_to_user_stream = {} self.room_to_user_streams = {} @@ -164,7 +166,12 @@ class Notifier(object): self.store = hs.get_datastore() self.pending_new_room_events = [] - self.replication_callbacks = [] + # Called when there are new things to stream over replication + self.replication_callbacks = [] # type: List[Callable[[], None]] + + # Called when remote servers have come back online after having been + # down. + self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -205,7 +212,7 @@ class Notifier(object): "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb): + def add_replication_callback(self, cb: Callable[[], None]): """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -213,6 +220,12 @@ class Notifier(object): """ self.replication_callbacks.append(cb) + def add_remote_server_up_callback(self, cb: Callable[[str], None]): + """Add a callback that will be called when synapse detects a server + has been + """ + self.remote_server_up_callbacks.append(cb) + def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -522,3 +535,15 @@ class Notifier(object): """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() + + def notify_remote_server_up(self, server: str): + """Notify any replication that a remote server has come back up + """ + # We call federation_sender directly rather than registering as a + # callback as a) we already have a reference to it and b) it introduces + # circular dependencies. 
+ if self.federation_sender: + self.federation_sender.wake_destination(server) + + for cb in self.remote_server_up_callbacks: + cb(server) diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index de5c101a58..5dae4648c0 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,7 +21,7 @@ from synapse.storage import Storage @defer.inlineCallbacks def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_user(user_id) + invites = yield store.get_invited_rooms_for_local_user(user_id) joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 29f35b9915..3aa6cb8b96 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -152,7 +152,7 @@ class SlavedEventStore( if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_user.invalidate((state_key,)) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: self.get_relations_for_event.invalidate_many((relates_to,)) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index aa7fd90e26..fc06a7b053 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, token, rows) - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes the slave store. Can be overriden in subclasses to handle more. """ - return self.store.process_replication_rows(stream_name, token, []) + self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting @@ -146,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): if d: d.callback(data) + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. 
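As the synapse/replication/tcp/client.py hunk above shows, ReplicationClientHandler.on_rdata and on_position are now coroutines, and a no-op on_remote_server_up hook has been added; the worker handlers earlier in this diff (appservice, pusher, synchrotron, user_dir, federation_sender) are updated accordingly. A sketch of a custom handler written against the new interface, assuming the class name and the reaction to the "events" stream purely for illustration:

    from synapse.replication.tcp.client import ReplicationClientHandler


    class ExampleWorkerReplicationHandler(ReplicationClientHandler):
        """Hypothetical worker-side handler using the coroutine interface."""

        def __init__(self, hs):
            super().__init__(hs.get_datastore())
            self.notifier = hs.get_notifier()

        async def on_rdata(self, stream_name, token, rows):
            # The base implementation pokes the slaved store; overrides must
            # now await it instead of yielding via inlineCallbacks.
            await super().on_rdata(stream_name, token, rows)
            if stream_name == "events":
                # React to newly replicated events here (illustrative).
                pass

        def on_remote_server_up(self, server: str):
            # New hook: called when a REMOTE_SERVER_UP command is received,
            # e.g. to wake queues for that destination.
            pass
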
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index cbb36b9acf..451671412d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -387,6 +387,20 @@ class UserIpCommand(Command): ) +class RemoteServerUpCommand(Command): + """Sent when a worker has detected that a remote server is no longer + "down" and retry timings should be reset. + + If sent from a client the server will relay to all other workers. + + Format:: + + REMOTE_SERVER_UP <server> + """ + + NAME = "REMOTE_SERVER_UP" + + _COMMANDS = ( ServerCommand, RdataCommand, @@ -401,6 +415,7 @@ _COMMANDS = ( RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, + RemoteServerUpCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, SyncCommand.NAME, + RemoteServerUpCommand.NAME, ) # The commands the client is allowed to send @@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = ( InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, + RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index db0353c996..131e5acb09 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -76,17 +76,17 @@ from synapse.replication.tcp.commands import ( PingCommand, PositionCommand, RdataCommand, + RemoteServerUpCommand, ReplicateCommand, ServerCommand, SyncCommand, UserSyncCommand, ) +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -from .streams import STREAMS_MAP - connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -241,19 +241,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND> + By default delegates to on_<COMMAND>, which should return an awaitable. 
Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + await handler(cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -326,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -429,16 +426,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) self.streamer.new_connection(self) - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( + async def on_USER_SYNC(self, cmd): + await self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - def on_REPLICATE(self, cmd): + async def on_REPLICATE(self, cmd): stream_name = cmd.stream_name token = cmd.token @@ -449,23 +446,26 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): for stream in iterkeys(self.streamer.streams_by_name) ] - return make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) else: - return self.subscribe_to_stream(stream_name, token) + await self.subscribe_to_stream(stream_name, token) - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) + async def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + async def on_REMOVE_PUSHER(self, cmd): + await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_INVALIDATE_CACHE(self, cmd): + self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.streamer.on_remote_server_up(cmd.data) + + async def on_USER_IP(self, cmd): + self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, @@ -474,8 +474,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): + async def subscribe_to_stream(self, stream_name, token): """Subscribe the remote to a stream. 
This invloves checking if they've missed anything and sending those @@ -487,7 +486,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( + updates, current_token = await self.streamer.get_stream_updates( stream_name, token ) @@ -560,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def send_sync(self, data): self.send_command(SyncCommand(data)) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) self.streamer.lost_connection(self) @@ -572,7 +574,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. Args: @@ -580,14 +582,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ raise NotImplementedError() @abc.abstractmethod - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data.""" raise NotImplementedError() @@ -597,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + raise NotImplementedError() + + @abc.abstractmethod def get_streams_to_replicate(self): """Called when a new connection has been established and we need to subscribe to streams. @@ -676,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): if not self.streams_connecting: self.handler.finished_connecting() - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): + async def on_RDATA(self, cmd): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -701,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + await self.handler.on_rdata(stream_name, cmd.token, rows) - def on_POSITION(self, cmd): + async def on_POSITION(self, cmd): # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. 
self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - return self.handler.on_position(cmd.stream_name, cmd.token) + await self.handler.on_position(cmd.stream_name, cmd.token) + + async def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.handler.on_remote_server_up(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index cbfdaf5773..6ebf944f66 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,6 @@ from six import itervalues from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge @@ -121,6 +120,7 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False @@ -155,8 +155,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -185,7 +184,7 @@ class ReplicationStreamer(object): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) @@ -196,7 +195,7 @@ class ReplicationStreamer(object): stream.upto_token, ) try: - updates, current_token = yield stream.get_updates() + updates, current_token = await stream.get_updates() except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -233,7 +232,7 @@ class ReplicationStreamer(object): self.is_looping = False @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): + async def get_stream_updates(self, stream_name, token): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -241,7 +240,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return stream.get_updates_since(token) + return await stream.get_updates_since(token) @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -252,22 +251,20 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. 
""" user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( + await self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): + async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher """ remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id=app_id, pushkey=push_key, user_id=user_id ) @@ -281,15 +278,24 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + async def on_user_ip( + self, user_id, access_token, ip, user_agent, device_id, last_seen + ): """The client saw a user request """ user_ip_cache_counter.inc() - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen ) - yield self._server_notices_sender.on_user_ip(user_id) + await self._server_notices_sender.on_user_ip(user_id) + + @measure_func("repl.on_remote_server_up") + def on_remote_server_up(self, server: str): + self.notifier.notify_remote_server_up(server) + + def send_remote_server_up(self, server: str): + for conn in self.connections: + conn.send_remote_server_up(server) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4ab0334fc1..e03e77199b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -19,8 +19,6 @@ import logging from collections import namedtuple from typing import Any -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -144,8 +142,7 @@ class Stream(object): self.upto_token = self.current_token() self.last_token = self.upto_token - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self): """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before), until the `upto_token` @@ -156,13 +153,12 @@ class Stream(object): list of ``(token, row)`` entries. ``row`` will be json-serialised and sent over the replication steam. """ - updates, current_token = yield self.get_updates_since(self.last_token) + updates, current_token = await self.get_updates_since(self.last_token) self.last_token = current_token return updates, current_token - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since(self, from_token): """Like get_updates except allows specifying from when we should stream updates @@ -182,15 +178,16 @@ class Stream(object): if from_token == current_token: return [], current_token + logger.info("get_updates_since: %s", self.__class__) if self._LIMITED: - rows = yield self.update_function( + rows = await self.update_function( from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. 
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function(from_token, current_token) + rows = await self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -295,9 +292,8 @@ class PushRulesStream(Stream): push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) + async def update_function(self, from_token, to_token, limit): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) return [(row[0], row[2]) for row in rows] @@ -413,9 +409,8 @@ class AccountDataStream(Stream): super(AccountDataStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( + async def update_function(self, from_token, to_token, limit): + global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 0843e5aa90..b3afabb8cd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,8 +19,6 @@ from typing import Tuple, Type import attr -from twisted.internet import defer - from ._base import Stream @@ -122,16 +120,15 @@ class EventsStream(Stream): super(EventsStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( + async def update_function(self, from_token, current_token, limit=None): + event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) event_updates = ( (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) - state_rows = yield self._store.get_all_updated_current_state_deltas( + state_rows = await self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 7a256b6ecb..50e080673b 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -206,10 +206,6 @@ class AuthRestServlet(RestServlet): return None elif stagetype == LoginType.TERMS: - if ("session" not in request.args or len(request.args["session"])) == 0: - raise SynapseError(400, "No session supplied") - - session = request.args["session"][0] authdict = {"session": session} success = await self.auth_handler.add_oob_auth( diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2a477ad22e..3d0fefb4df 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -71,6 +71,8 @@ class VersionsRestServlet(RestServlet): # Implements support for label-based filtering as described in # MSC2326. 
"org.matrix.label_based_filtering": True, + # Implements support for cross signing as described in MSC1756 + "org.matrix.e2e_cross_signing": True, }, }, ) diff --git a/synapse/server.pyi b/synapse/server.pyi index b5e0b57095..0731403047 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +import twisted.internet + import synapse.api.auth import synapse.config.homeserver import synapse.federation.sender @@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account import synapse.handlers.device import synapse.handlers.e2e_keys import synapse.handlers.message +import synapse.handlers.presence import synapse.handlers.room import synapse.handlers.room_member import synapse.handlers.set_password import synapse.http.client +import synapse.notifier import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -85,3 +89,11 @@ class HomeServer(object): self, ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender: pass + def get_notifier(self) -> synapse.notifier.Notifier: + pass + def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + pass + def get_clock(self) -> synapse.util.Clock: + pass + def get_reactor(self) -> twisted.internet.base.ReactorBase: + pass diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 2dac90578c..f7432c8d2f 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -105,7 +105,7 @@ class ServerNoticesManager(object): assert self._is_mine_id(user_id), "Cannot send server notices to remote users" - rooms = yield self._store.get_rooms_for_user_where_membership_is( + rooms = yield self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) system_mxid = self._config.server_notices_mxid diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5accc071ab..cacd0c0c2b 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional from six import iteritems, itervalues @@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids): def resolve_events_with_store( room_id: str, room_version: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index b2f9865f39..d6c34ce3b7 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,7 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional from six import iteritems, iterkeys, itervalues @@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions 
import RoomVersions from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks def resolve_events_with_store( room_id: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_map_factory: Callable, ): diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 72fb8a6317..6216fdd204 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,7 +16,7 @@ import heapq import itertools import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from six import iteritems, itervalues @@ -27,6 +27,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) def resolve_events_with_store( room_id: str, room_version: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", ): @@ -393,12 +394,12 @@ def _iterative_auth_checks( room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (dict[tuple[str, str], str]): The set of state to start with + base_state (StateMap[str]): The set of state to start with event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) Returns: - Deferred[dict[tuple[str, str], str]]: Returns the final updated state + Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 54ed8574c4..bf91512daf 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import PostgresEngine -from synapse.util import batch_iter +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 9a828231c4..f0a7962dd0 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -33,13 +33,13 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key -from synapse.util import batch_iter from synapse.util.caches.descriptors import ( Cache, cached, cachedInlineCallbacks, cachedList, ) +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 58f35d7f56..bb69c20448 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -43,9 +43,9 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.database import Database from 
synapse.types import RoomStreamToken, get_domain_from_id -from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) @@ -128,6 +128,7 @@ class EventsStore( hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self.is_mine_id = hs.is_mine_id @defer.inlineCallbacks def _read_forward_extremities(self): @@ -547,6 +548,34 @@ class EventsStore( ], ) + # Note: Do we really want to delete rows here (that we do not + # subsequently reinsert below)? While technically correct it means + # we have no record of the fact the user *was* a member of the + # room but got, say, state reset out of it. + if to_delete or to_insert: + txn.executemany( + "DELETE FROM local_current_membership" + " WHERE room_id = ? AND user_id = ?", + ( + (room_id, state_key) + for etype, state_key in itertools.chain(to_delete, to_insert) + if etype == EventTypes.Member and self.is_mine_id(state_key) + ), + ) + + if to_insert: + txn.executemany( + """INSERT INTO local_current_membership + (room_id, user_id, event_id, membership) + VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) + """, + [ + (room_id, key[1], ev_id, ev_id) + for key, ev_id in to_insert.items() + if key[0] == EventTypes.Member and self.is_mine_id(key[1]) + ], + ) + txn.call_after( self._curr_state_delta_stream_cache.entity_has_changed, room_id, @@ -1724,6 +1753,7 @@ class EventsStore( "local_invites", "room_account_data", "room_tags", + "local_current_membership", ): logger.info("[purge] removing %s from %s", room_id, table) txn.execute("DELETE FROM %s WHERE room_id=?" 
% (table,), (room_id,)) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 0cce5232f5..3b93e0597a 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -37,8 +37,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util import batch_iter from synapse.util.caches.descriptors import Cache +from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index 6b12f5a75f..ba89c68c9f 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -23,8 +23,8 @@ from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore from synapse.storage.keys import FetchKeyResult -from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index a2c83e0867..604c8b7ddd 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -17,8 +17,8 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState -from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 70ff5751b6..9acef7c950 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0]: row[1] for row in txn} @cached() - def get_invited_rooms_for_user(self, user_id): - """ Get all the rooms the user is invited to + def get_invited_rooms_for_local_user(self, user_id): + """ Get all the rooms the *local* user is invited to + Args: user_id (str): The user ID. Returns: A deferred list of RoomsForUser. """ - return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) + return self.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE] + ) @defer.inlineCallbacks - def get_invite_for_user_in_room(self, user_id, room_id): - """Gets the invite for the given user and room + def get_invite_for_local_user_in_room(self, user_id, room_id): + """Gets the invite for the given *local* user and room Args: user_id (str) @@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred: Resolves to either a RoomsForUser or None if no invite was found. 
""" - invites = yield self.get_invited_rooms_for_user(user_id) + invites = yield self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None @defer.inlineCallbacks - def get_rooms_for_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this user where the membership for this user + def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. @@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): return defer.succeed(None) rooms = yield self.db.runInteraction( - "get_rooms_for_user_where_membership_is", - self._get_rooms_for_user_where_membership_is_txn, + "get_rooms_for_local_user_where_membership_is", + self._get_rooms_for_local_user_where_membership_is_txn, user_id, membership_list, ) @@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore): forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] - def _get_rooms_for_user_where_membership_is_txn( + def _get_rooms_for_local_user_where_membership_is_txn( self, txn, user_id, membership_list ): + # Paranoia check. + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r" + % (user_id,), + ) - do_invite = Membership.INVITE in membership_list - membership_list = [m for m in membership_list if m != Membership.INVITE] - - results = [] - if membership_list: - if self._current_state_events_membership_up_to_date: - clause, args = make_in_list_sql_clause( - self.database_engine, "c.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - else: - clause, args = make_in_list_sql_clause( - self.database_engine, "m.membership", membership_list - ) - sql = """ - SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND state_key = ? - AND %s - """ % ( - clause, - ) - - txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] + clause, args = make_in_list_sql_clause( + self.database_engine, "c.membership", membership_list + ) - if do_invite: - sql = ( - "SELECT i.room_id, inviter, i.event_id, e.stream_ordering" - " FROM local_invites as i" - " INNER JOIN events as e USING (event_id)" - " WHERE invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) + sql = """ + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + FROM local_current_membership AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + user_id = ? 
+ AND %s + """ % ( + clause, + ) - txn.execute(sql, (user_id,)) - results.extend( - RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) - for r in self.db.cursor_to_dict(txn) - ) + txn.execute(sql, (user_id, *args)) + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] return results - @cachedInlineCallbacks(max_entries=500000, iterable=True) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user_with_stream_ordering(self, user_id): - """Returns a set of room_ids the user is currently joined to + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. Args: user_id (str) @@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore): the rooms the user is in currently, along with the stream ordering of the most recent join for that user and room. """ - rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN] - ) - return frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms + return self.db.runInteraction( + "get_rooms_for_user_with_stream_ordering", + self._get_rooms_for_user_with_stream_ordering_txn, + user_id, ) + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + # We use `current_state_events` here and not `local_current_membership` + # as a) this gets called with remote users and b) this only gets called + # for rooms the server is participating in. + if self._current_state_events_membership_up_to_date: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND c.membership = ? + """ + else: + sql = """ + SELECT room_id, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND state_key = ? + AND m.membership = ? + """ + + txn.execute(sql, (user_id, Membership.JOIN)) + results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + + return results + @defer.inlineCallbacks def get_rooms_for_user(self, user_id, on_invalidate=None): - """Returns a set of room_ids the user is currently joined to + """Returns a set of room_ids the user is currently joined to. + + If a remote user only returns rooms this server is currently + participating in. """ rooms = yield self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate @@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): event.internal_metadata.stream_ordering, ) txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) + self.get_invited_rooms_for_local_user.invalidate, (event.state_key,) ) # We update the local_invites table only if the event is "current", @@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): ), ) + # We also update the `local_current_membership` table with + # latest invite info. This will usually get updated by the + # `current_state_events` handling, unless its an outlier. + if event.internal_metadata.is_outlier(): + # This should only happen for out of band memberships, so + # we add a paranoia check. 
+ assert event.internal_metadata.is_out_of_band_membership() + + self.db.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) + @defer.inlineCallbacks def locally_reject_invite(self, user_id, room_id): sql = ( @@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def f(txn, stream_ordering): txn.execute(sql, (stream_ordering, True, room_id, user_id)) + # We also clear this entry from `local_current_membership`. + # Ideally we'd point to a leave event, but we don't have one, so + # nevermind. + self.db.simple_delete_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + ) + with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py new file mode 100644 index 0000000000..601c236c4a --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# We create a new table called `local_current_membership` that stores the latest +# membership state of local users in rooms, which helps track leaves/bans/etc +# even if the server has left the room (and so has deleted the room from +# `current_state_events`). This will also include outstanding invites for local +# users for rooms the server isn't in. +# +# If the server isn't and hasn't been in the room then it will only include +# outsstanding invites, and not e.g. pre-emptive bans of local users. +# +# If the server later rejoins a room `local_current_membership` can simply be +# replaced with the new current state of the room (which results in the +# equivalent behaviour as if the server had remained in the room). + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # We need to do the insert in `run_upgrade` section as we don't have access + # to `config` in `run_create`. + + # This upgrade may take a bit of time for large servers (e.g. one minute for + # matrix.org) but means we avoid a lots of book keeping required to do it as + # a background update. + + # We check if the `current_state_events.membership` is up to date by + # checking if the relevant background update has finished. If it has + # finished we can avoid doing a join against `room_memberships`, which + # speesd things up. + cur.execute( + """SELECT 1 FROM background_updates + WHERE update_name = 'current_state_events_membership' + """ + ) + current_state_membership_up_to_date = not bool(cur.fetchone()) + + # Cheekily drop and recreate indices, as that is faster. 
+ cur.execute("DROP INDEX local_current_membership_idx") + cur.execute("DROP INDEX local_current_membership_room_idx") + + if current_state_membership_up_to_date: + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, c.membership + FROM current_state_events AS c + WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ? + """ + else: + # We can't rely on the membership column, so we need to join against + # `room_memberships`. + sql = """ + INSERT INTO local_current_membership (room_id, user_id, event_id, membership) + SELECT c.room_id, state_key AS user_id, event_id, r.membership + FROM current_state_events AS c + INNER JOIN room_memberships AS r USING (event_id) + WHERE type = 'm.room.member' and state_key like '%' || ? + """ + cur.execute(sql, (config.server_name,)) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) + + +def run_create(cur, database_engine, *args, **kwargs): + cur.execute( + """ + CREATE TABLE local_current_membership ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + membership TEXT NOT NULL + )""" + ) + + cur.execute( + "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)" + ) + cur.execute( + "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)" + ) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index d07440e3ed..33bebd1c48 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): + def get_filtered_current_state_ids( + self, room_id: str, state_filter: StateFilter = StateFilter.all() + ): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state Args: - room_id (str) - state_filter (StateFilter): The state filter used to fetch state + room_id + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[tuple[str, str], str]]: Map from type/state_key to - event ID. + defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. 
""" where_clause, where_args = state_filter.make_sql_filter_clause() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index d53695f238..c4ee9b7ccb 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Dict, Iterable, List, Set, Tuple from six import iteritems from six.moves import range @@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): + """Returns the state groups for a given set of groups from the + database, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ results = {} @@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want + groups: list of state groups for which we want to get the state. - state_filter (StateFilter): The state filter used to fetch state + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + def _get_state_for_groups_using_cache( + self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter + ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. Args: - groups (iterable[int]): list of state groups for which we want - to get the state. 
- cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. + groups: list of state groups for which we want to get the state. + cache: the cache of group ids to state dicts which + we will pass through - either the normal state cache or the + specific members state cache. + state_filter: The state filter used to fetch state from the + database. Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. + Tuple of dict of state_group_id to state map of entries in the + cache, and the state group ids either missing from the cache or + incomplete. """ results = {} incomplete_groups = set() diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e70026b80a..e86984cd50 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 56 +SCHEMA_VERSION = 57 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index d6a7bd7834..fdc0abf5cf 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -34,7 +34,7 @@ class PurgeEventsStorage(object): """ state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.main.purge_room_state(room_id, state_groups_to_delete) + yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) @defer.inlineCallbacks def purge_history(self, room_id, token, delete_local_events): diff --git a/synapse/storage/state.py b/synapse/storage/state.py index cbeb586014..c522c80922 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Iterable, List, TypeVar from six import iteritems, itervalues @@ -22,9 +23,13 @@ import attr from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.types import StateMap logger = logging.getLogger(__name__) +# Used for generic functions below +T = TypeVar("T") + @attr.s(slots=True) class StateFilter(object): @@ -233,14 +238,14 @@ class StateFilter(object): return len(self.concrete_types()) - def filter_state(self, state_dict): + def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: """Returns the state filtered with by this StateFilter Args: - state (dict[tuple[str, str], Any]): The state map to filter + state: The state map to filter Returns: - dict[tuple[str, str], Any]: The filtered state map + The filtered state map """ if self.is_full(): return dict(state_dict) @@ -333,12 +338,12 @@ class StateGroupStorage(object): def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group): + def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. 
Returns: - Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): + Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) """ @@ -353,7 +358,7 @@ class StateGroupStorage(object): event_ids (iterable[str]): ids of the events Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: + Deferred[dict[int, StateMap[str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: @@ -410,17 +415,18 @@ class StateGroupStorage(object): for group, event_id_map in iteritems(group_to_ids) } - def _get_state_groups_from_groups(self, groups, state_filter): + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): """Returns the state groups for a given set of groups, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -519,7 +525,9 @@ class StateGroupStorage(object): state_map = yield self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -529,8 +537,7 @@ class StateGroupStorage(object): state_filter (StateFilter): The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. 
""" return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/synapse/types.py b/synapse/types.py index cd996c0b5a..65e4d8c181 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -17,6 +17,7 @@ import re import string import sys from collections import namedtuple +from typing import Dict, Tuple, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container, TypeVar + from typing import Sized, Iterable, Container T_co = TypeVar("T_co", covariant=True) @@ -36,6 +37,12 @@ else: __slots__ = () +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateMap = Dict[Tuple[str, str], T] + + class Requester( namedtuple( "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 7856353002..60f0de70f7 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -15,7 +15,6 @@ import logging import re -from itertools import islice import attr @@ -107,22 +106,6 @@ class Clock(object): raise -def batch_iter(iterable, size): - """batch an iterable up into tuples with a maximum size - - Args: - iterable (iterable): the iterable to slice - size (int): the maximum batch size - - Returns: - an iterator over the chunks - """ - # make sure we can deal with iterables like lists too - sourceiter = iter(iterable) - # call islice until it returns an empty tuple - return iter(lambda: tuple(islice(sourceiter, size)), ()) - - def log_failure(failure, msg, consumeErrors=True): """Creates a function suitable for passing to `Deferred.addErrback` that logs any failures that occur. diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py new file mode 100644 index 0000000000..06faeebe7f --- /dev/null +++ b/synapse/util/iterutils.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import islice +from typing import Iterable, Iterator, Sequence, Tuple, TypeVar + +T = TypeVar("T") + + +def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]: + """batch an iterable up into tuples with a maximum size + + Args: + iterable (iterable): the iterable to slice + size (int): the maximum batch size + + Returns: + an iterator over the chunks + """ + # make sure we can deal with iterables like lists too + sourceiter = iter(iterable) + # call islice until it returns an empty tuple + return iter(lambda: tuple(islice(sourceiter, size)), ()) + + +ISeq = TypeVar("ISeq", bound=Sequence, covariant=True) + + +def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: + """Split the given sequence into chunks of the given size + + The last chunk may be shorter than the given size. 
+ + If the input is empty, no chunks are returned. + """ + return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen)) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 2705cbe5f8..bb62db4637 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -34,7 +34,7 @@ def load_module(provider): provider_class = getattr(module, clz) try: - provider_config = provider_class.parse_config(provider["config"]) + provider_config = provider_class.parse_config(provider.get("config")) except Exception as e: raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e)) |
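
The new REMOTE_SERVER_UP command added to synapse/replication/tcp/commands.py carries a single field, the remote server name, which the handlers in protocol.py and resource.py read back as cmd.data. A rough sketch of what one such line looks like on the replication TCP stream, assuming the usual single-field command parsing (illustrative only, not part of the change):

line = "REMOTE_SERVER_UP matrix.example.com"
name, _, data = line.partition(" ")
assert name == "REMOTE_SERVER_UP"
assert data == "matrix.example.com"  # handlers receive this value as cmd.data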
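
handle_command in synapse/replication/tcp/protocol.py now awaits the per-command handler that it looks up by command name. Reduced to a toy (ToyProtocol and on_PING below are invented purely to show the dispatch shape):

import asyncio

class ToyProtocol:
    async def handle_command(self, name: str, data: str):
        # Look up on_<NAME> and await it, mirroring the new async handle_command.
        handler = getattr(self, "on_%s" % (name,))
        await handler(data)

    async def on_PING(self, data: str):
        print("pong", data)

asyncio.run(ToyProtocol().handle_command("PING", "1579000000000"))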
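
The local_current_membership delta only copies rows for local users, which the SQL expresses as state_key LIKE '%' || ? with the server name bound as the parameter. The same filter written out in plain Python, purely for illustration (server_name stands in for config.server_name):

server_name = "example.org"  # stands in for config.server_name
state_keys = ["@alice:example.org", "@bob:remote.example", "@carol:example.org"]

# Mirrors `state_key LIKE '%' || ?`: keep only MXIDs that end with our server name.
local_users = [k for k in state_keys if k.endswith(server_name)]
assert local_users == ["@alice:example.org", "@carol:example.org"]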
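
Several signatures in state/ and storage/ replace Dict[Tuple[str, str], str] with the new StateMap alias from synapse/types.py. A self-contained sketch of how the alias is used (the helper function and sample data are invented for illustration):

from typing import Dict, Tuple, TypeVar

T = TypeVar("T")
StateMap = Dict[Tuple[str, str], T]  # keyed by (event type, state_key)

def room_name_event_id(state: StateMap[str]) -> str:
    # A StateMap[str] is just a dict of (type, state_key) -> event ID.
    return state.get(("m.room.name", ""), "")

print(room_name_event_id({("m.room.name", ""): "$abc123"}))  # -> $abc123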
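
The batch_iter helper moves from synapse.util to the new synapse.util.iterutils module, which also gains chunk_seq. A small usage sketch, assuming a checkout with these changes is importable; the expected results follow from the implementations shown above:

from synapse.util.iterutils import batch_iter, chunk_seq

# batch_iter groups any iterable into tuples of at most `size` items.
assert list(batch_iter(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]

# chunk_seq slices a sequence into pieces of at most `maxlen`; the last
# piece may be shorter, and an empty input yields no chunks.
assert list(chunk_seq("abcdefg", 3)) == ["abc", "def", "g"]
assert list(chunk_seq("", 3)) == []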