diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 505bac1308..cef4439477 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -344,3 +344,15 @@ jobs:
env:
COMPLEMENT_BASE_IMAGE: complement-synapse:latest
working-directory: complement
+
+ # a job which marks all the other jobs as complete, thus allowing PRs to be merged.
+ tests-done:
+ needs:
+ - trial
+ - trial-olddeps
+ - sytest
+ - portdb
+ - complement
+ runs-on: ubuntu-latest
+ steps:
+ - run: "true"
\ No newline at end of file
diff --git a/changelog.d/10332.feature b/changelog.d/10332.feature
new file mode 100644
index 0000000000..091947ff22
--- /dev/null
+++ b/changelog.d/10332.feature
@@ -0,0 +1 @@
+Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.
diff --git a/changelog.d/10348.misc b/changelog.d/10348.misc
new file mode 100644
index 0000000000..b2275a1350
--- /dev/null
+++ b/changelog.d/10348.misc
@@ -0,0 +1 @@
+Run `pyupgrade` on the codebase.
\ No newline at end of file
diff --git a/changelog.d/10382.misc b/changelog.d/10382.misc
new file mode 100644
index 0000000000..eed2d8552a
--- /dev/null
+++ b/changelog.d/10382.misc
@@ -0,0 +1 @@
+Convert internal type variable syntax to reflect wider ecosystem use.
\ No newline at end of file
diff --git a/changelog.d/10386.removal b/changelog.d/10386.removal
new file mode 100644
index 0000000000..800a6143d7
--- /dev/null
+++ b/changelog.d/10386.removal
@@ -0,0 +1 @@
+The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information.
diff --git a/changelog.d/10404.bugfix b/changelog.d/10404.bugfix
new file mode 100644
index 0000000000..2e095b6402
--- /dev/null
+++ b/changelog.d/10404.bugfix
@@ -0,0 +1 @@
+Responses from `/make_{join,leave,knock}` no longer include signatures, which will turn out to be invalid after events are returned to `/send_{join,leave,knock}`.
diff --git a/changelog.d/10414.bugfix b/changelog.d/10414.bugfix
new file mode 100644
index 0000000000..bfebed8d29
--- /dev/null
+++ b/changelog.d/10414.bugfix
@@ -0,0 +1 @@
+Fix a number of logged errors caused by remote servers being down.
diff --git a/changelog.d/10418.misc b/changelog.d/10418.misc
new file mode 100644
index 0000000000..eed2d8552a
--- /dev/null
+++ b/changelog.d/10418.misc
@@ -0,0 +1 @@
+Convert internal type variable syntax to reflect wider ecosystem use.
\ No newline at end of file
diff --git a/changelog.d/10421.misc b/changelog.d/10421.misc
new file mode 100644
index 0000000000..385cbe07af
--- /dev/null
+++ b/changelog.d/10421.misc
@@ -0,0 +1 @@
+Remove unused `events_by_room` code (tech debt).
diff --git a/changelog.d/10427.feature b/changelog.d/10427.feature
new file mode 100644
index 0000000000..091947ff22
--- /dev/null
+++ b/changelog.d/10427.feature
@@ -0,0 +1 @@
+Add a new version of the R30 phone-home metric, which removes a false impression of retention given by the old R30 metric.
diff --git a/changelog.d/10430.misc b/changelog.d/10430.misc
new file mode 100644
index 0000000000..a017cf4ac9
--- /dev/null
+++ b/changelog.d/10430.misc
@@ -0,0 +1 @@
+Add a github actions job recording success of other jobs.
diff --git a/changelog.d/9884.feature b/changelog.d/9884.feature
new file mode 100644
index 0000000000..525fd2f93c
--- /dev/null
+++ b/changelog.d/9884.feature
@@ -0,0 +1 @@
+Add a module type for the account validity feature.
diff --git a/docs/modules.md b/docs/modules.md
index bec1c06d15..9a430390a4 100644
--- a/docs/modules.md
+++ b/docs/modules.md
@@ -63,7 +63,7 @@ Modules can register web resources onto Synapse's web server using the following
API method:
```python
-def ModuleApi.register_web_resource(path: str, resource: IResource)
+def ModuleApi.register_web_resource(path: str, resource: IResource) -> None
```
The path is the full absolute path to register the resource at. For example, if you
@@ -91,12 +91,17 @@ are split in categories. A single module may implement callbacks from multiple c
and is under no obligation to implement all callbacks from the categories it registers
callbacks for.
+Modules can register callbacks using one of the module API's `register_[...]_callbacks`
+methods. The callback functions are passed to these methods as keyword arguments, with
+the callback name as the argument name and the function as its value. This is demonstrated
+in the example below. A `register_[...]_callbacks` method exists for each module type
+documented in this section.
+
#### Spam checker callbacks
-To register one of the callbacks described in this section, a module needs to use the
-module API's `register_spam_checker_callbacks` method. The callback functions are passed
-to `register_spam_checker_callbacks` as keyword arguments, with the callback name as the
-argument name and the function as its value. This is demonstrated in the example below.
+Spam checker callbacks allow module developers to implement spam mitigation actions for
+Synapse instances. Spam checker callbacks can be registered using the module API's
+`register_spam_checker_callbacks` method.
The available spam checker callbacks are:
@@ -115,7 +120,7 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool
Called when processing an invitation. The module must return a `bool` indicating whether
the inviter can invite the invitee to the given room. Both inviter and invitee are
-represented by their Matrix user ID (i.e. `@alice:example.com`).
+represented by their Matrix user ID (e.g. `@alice:example.com`).
```python
async def user_may_create_room(user: str) -> bool
@@ -181,13 +186,103 @@ The arguments passed to this callback are:
```python
async def check_media_file_for_spam(
file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper",
- file_info: "synapse.rest.media.v1._base.FileInfo"
+ file_info: "synapse.rest.media.v1._base.FileInfo",
) -> bool
```
Called when storing a local or remote file. The module must return a boolean indicating
whether the given file can be stored in the homeserver's media store.
+#### Account validity callbacks
+
+Account validity callbacks allow module developers to add extra steps to verify the
+validity on an account, i.e. see if a user can be granted access to their account on the
+Synapse instance. Account validity callbacks can be registered using the module API's
+`register_account_validity_callbacks` method.
+
+The available account validity callbacks are:
+
+```python
+async def is_user_expired(user: str) -> Optional[bool]
+```
+
+Called when processing any authenticated request (except for logout requests). The module
+can return a `bool` to indicate whether the user has expired and should be locked out of
+their account, or `None` if the module wasn't able to figure it out. The user is
+represented by their Matrix user ID (e.g. `@alice:example.com`).
+
+If the module returns `True`, the current request will be denied with the error code
+`ORG_MATRIX_EXPIRED_ACCOUNT` and the HTTP status code 403. Note that this doesn't
+invalidate the user's access token.
+
+```python
+async def on_user_registration(user: str) -> None
+```
+
+Called after successfully registering a user, in case the module needs to perform extra
+operations to keep track of them. (e.g. add them to a database table). The user is
+represented by their Matrix user ID.
+
+#### Third party rules callbacks
+
+Third party rules callbacks allow module developers to add extra checks to verify the
+validity of incoming events. Third party event rules callbacks can be registered using
+the module API's `register_third_party_rules_callbacks` method.
+
+The available third party rules callbacks are:
+
+```python
+async def check_event_allowed(
+ event: "synapse.events.EventBase",
+ state_events: "synapse.types.StateMap",
+) -> Tuple[bool, Optional[dict]]
+```
+
+**<span style="color:red">
+This callback is very experimental and can and will break without notice. Module developers
+are encouraged to implement `check_event_for_spam` from the spam checker category instead.
+</span>**
+
+Called when processing any incoming event, with the event and a `StateMap`
+representing the current state of the room the event is being sent into. A `StateMap` is
+a dictionary that maps tuples containing an event type and a state key to the
+corresponding state event. For example retrieving the room's `m.room.create` event from
+the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
+The module must return a boolean indicating whether the event can be allowed.
+
+Note that this callback function processes incoming events coming via federation
+traffic (on top of client traffic). This means denying an event might cause the local
+copy of the room's history to diverge from that of remote servers. This may cause
+federation issues in the room. It is strongly recommended to only deny events using this
+callback function if the sender is a local user, or in a private federation in which all
+servers are using the same module, with the same configuration.
+
+If the boolean returned by the module is `True`, it may also tell Synapse to replace the
+event with new data by returning the new event's data as a dictionary. In order to do
+that, it is recommended the module calls `event.get_dict()` to get the current event as a
+dictionary, and modify the returned dictionary accordingly.
+
+Note that replacing the event only works for events sent by local users, not for events
+received over federation.
+
+```python
+async def on_create_room(
+ requester: "synapse.types.Requester",
+ request_content: dict,
+ is_requester_admin: bool,
+) -> None
+```
+
+Called when processing a room creation request, with the `Requester` object for the user
+performing the request, a dictionary representing the room creation request's JSON body
+(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom)
+for a list of possible parameters), and a boolean indicating whether the user performing
+the request is a server admin.
+
+Modules can modify the `request_content` (by e.g. adding events to its `initial_state`),
+or deny the room's creation by raising a `module_api.errors.SynapseError`.
+
+
### Porting an existing module that uses the old interface
In order to port a module that uses Synapse's old module interface, its author needs to:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index a45732a246..853c2f6899 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1310,91 +1310,6 @@ account_threepid_delegates:
#auto_join_rooms_for_guests: false
-## Account Validity ##
-
-# Optional account validity configuration. This allows for accounts to be denied
-# any request after a given period.
-#
-# Once this feature is enabled, Synapse will look for registered users without an
-# expiration date at startup and will add one to every account it found using the
-# current settings at that time.
-# This means that, if a validity period is set, and Synapse is restarted (it will
-# then derive an expiration date from the current validity period), and some time
-# after that the validity period changes and Synapse is restarted, the users'
-# expiration dates won't be updated unless their account is manually renewed. This
-# date will be randomly selected within a range [now + period - d ; now + period],
-# where d is equal to 10% of the validity period.
-#
-account_validity:
- # The account validity feature is disabled by default. Uncomment the
- # following line to enable it.
- #
- #enabled: true
-
- # The period after which an account is valid after its registration. When
- # renewing the account, its validity period will be extended by this amount
- # of time. This parameter is required when using the account validity
- # feature.
- #
- #period: 6w
-
- # The amount of time before an account's expiry date at which Synapse will
- # send an email to the account's email address with a renewal link. By
- # default, no such emails are sent.
- #
- # If you enable this setting, you will also need to fill out the 'email' and
- # 'public_baseurl' configuration sections.
- #
- #renew_at: 1w
-
- # The subject of the email sent out with the renewal link. '%(app)s' can be
- # used as a placeholder for the 'app_name' parameter from the 'email'
- # section.
- #
- # Note that the placeholder must be written '%(app)s', including the
- # trailing 's'.
- #
- # If this is not set, a default value is used.
- #
- #renew_email_subject: "Renew your %(app)s account"
-
- # Directory in which Synapse will try to find templates for the HTML files to
- # serve to the user when trying to renew an account. If not set, default
- # templates from within the Synapse package will be used.
- #
- # The currently available templates are:
- #
- # * account_renewed.html: Displayed to the user after they have successfully
- # renewed their account.
- #
- # * account_previously_renewed.html: Displayed to the user if they attempt to
- # renew their account with a token that is valid, but that has already
- # been used. In this case the account is not renewed again.
- #
- # * invalid_token.html: Displayed to the user when they try to renew an account
- # with an unknown or invalid renewal token.
- #
- # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
- # default template contents.
- #
- # The file name of some of these templates can be configured below for legacy
- # reasons.
- #
- #template_dir: "res/templates"
-
- # A custom file name for the 'account_renewed.html' template.
- #
- # If not set, the file is assumed to be named "account_renewed.html".
- #
- #account_renewed_html_path: "account_renewed.html"
-
- # A custom file name for the 'invalid_token.html' template.
- #
- # If not set, the file is assumed to be named "invalid_token.html".
- #
- #invalid_token_html_path: "invalid_token.html"
-
-
## Metrics ###
# Enable collection and rendering of performance metrics
@@ -2739,19 +2654,6 @@ stats:
# action: allow
-# Server admins can define a Python module that implements extra rules for
-# allowing or denying incoming events. In order to work, this module needs to
-# override the methods defined in synapse/events/third_party_rules.py.
-#
-# This feature is designed to be used in closed federations only, where each
-# participating server enforces the same rules.
-#
-#third_party_event_rules:
-# module: "my_custom_project.SuperRulesSet"
-# config:
-# example_option: 'things'
-
-
## Opentracing ##
# These settings enable opentracing, which implements distributed tracing.
diff --git a/docs/upgrade.md b/docs/upgrade.md
index db0450f563..c8f4a2c171 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -86,6 +86,19 @@ process, for example:
```
+# Upgrading to v1.39.0
+
+## Deprecation of the current third-party rules module interface
+
+The current third-party rules module interface is deprecated in favour of the new generic
+modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer
+to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface)
+to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules)
+to update their configuration once the modules they are using have been updated.
+
+We plan to remove support for the current third-party rules interface in September 2021.
+
+
# Upgrading to v1.38.0
## Re-indexing of `events` table on Postgres databases
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 8916e6fa2f..05699714ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -62,6 +62,7 @@ class Auth:
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
+ self._account_validity_handler = hs.get_account_validity_handler()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
@@ -69,9 +70,6 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
- self._account_validity_enabled = (
- hs.config.account_validity.account_validity_enabled
- )
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@@ -187,12 +185,17 @@ class Auth:
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
- if self._account_validity_enabled and not allow_expired:
- if await self.store.is_account_expired(
- user_info.user_id, self.clock.time_msec()
+ if not allow_expired:
+ if await self._account_validity_handler.is_user_expired(
+ user_info.user_id
):
+ # Raise the error if either an account validity module has determined
+ # the account has expired, or the legacy account validity
+ # implementation is enabled and determined the account has expired
raise AuthError(
- 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
+ 403,
+ "User account has expired",
+ errcode=Codes.EXPIRED_ACCOUNT,
)
device_id = user_info.device_id
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index b30571fe49..50a02f51f5 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -368,6 +369,7 @@ async def start(hs: "HomeServer"):
module(config=config, api=module_api)
load_legacy_spam_checkers(hs)
+ load_legacy_third_party_event_rules(hs)
# If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b43d858f59..c3d4992518 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -395,10 +395,8 @@ class GenericWorkerServer(HomeServer):
elif listener.type == "metrics":
if not self.config.enable_metrics:
logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
)
else:
_base.listen_metrics(listener.bind_addresses, listener.port)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7af56ac136..920b34d97b 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -305,10 +305,8 @@ class SynapseHomeServer(HomeServer):
elif listener.type == "metrics":
if not self.config.enable_metrics:
logger.warning(
- (
- "Metrics listener configured, but "
- "enable_metrics is not True!"
- )
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
)
else:
_base.listen_metrics(listener.bind_addresses, listener.port)
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 8f86cecb76..96defac1d2 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -71,6 +71,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
# General statistics
#
+ store = hs.get_datastore()
+
stats["homeserver"] = hs.config.server_name
stats["server_context"] = hs.config.server_context
stats["timestamp"] = now
@@ -79,34 +81,38 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
- stats["total_users"] = await hs.get_datastore().count_all_users()
+ stats["total_users"] = await store.count_all_users()
- total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
+ total_nonbridged_users = await store.count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
- daily_user_type_results = await hs.get_datastore().count_daily_user_type()
+ daily_user_type_results = await store.count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
- room_count = await hs.get_datastore().get_room_count()
+ room_count = await store.get_room_count()
stats["total_room_count"] = room_count
- stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
- stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
- daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+ stats["daily_active_users"] = await store.count_daily_users()
+ stats["monthly_active_users"] = await store.count_monthly_users()
+ daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms()
stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
- stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
- daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+ stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages()
+ daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages()
stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
- stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
- daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+ stats["daily_active_rooms"] = await store.count_daily_active_rooms()
+ stats["daily_messages"] = await store.count_daily_messages()
+ daily_sent_messages = await store.count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
- r30_results = await hs.get_datastore().count_r30_users()
+ r30_results = await store.count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
+ r30v2_results = await store.count_r30_users()
+ for name, count in r30v2_results.items():
+ stats["r30v2_users_" + name] = count
+
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@@ -115,8 +121,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
#
# This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
+ stats["database_engine"] = store.db_pool.engine.module.__name__
+ stats["database_server_version"] = store.db_pool.engine.server_version
#
# Logging configuration
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index 957de7f3a6..6be4eafe55 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -18,6 +18,21 @@ class AccountValidityConfig(Config):
section = "account_validity"
def read_config(self, config, **kwargs):
+ """Parses the old account validity config. The config format looks like this:
+
+ account_validity:
+ enabled: true
+ period: 6w
+ renew_at: 1w
+ renew_email_subject: "Renew your %(app)s account"
+ template_dir: "res/templates"
+ account_renewed_html_path: "account_renewed.html"
+ invalid_token_html_path: "invalid_token.html"
+
+ We expect admins to use modules for this feature (which is why it doesn't appear
+ in the sample config file), but we want to keep support for it around for a bit
+ for backwards compatibility.
+ """
account_validity_config = config.get("account_validity") or {}
self.account_validity_enabled = account_validity_config.get("enabled", False)
self.account_validity_renew_by_email_enabled = (
@@ -75,90 +90,3 @@ class AccountValidityConfig(Config):
],
account_validity_template_dir,
)
-
- def generate_config_section(self, **kwargs):
- return """\
- ## Account Validity ##
-
- # Optional account validity configuration. This allows for accounts to be denied
- # any request after a given period.
- #
- # Once this feature is enabled, Synapse will look for registered users without an
- # expiration date at startup and will add one to every account it found using the
- # current settings at that time.
- # This means that, if a validity period is set, and Synapse is restarted (it will
- # then derive an expiration date from the current validity period), and some time
- # after that the validity period changes and Synapse is restarted, the users'
- # expiration dates won't be updated unless their account is manually renewed. This
- # date will be randomly selected within a range [now + period - d ; now + period],
- # where d is equal to 10% of the validity period.
- #
- account_validity:
- # The account validity feature is disabled by default. Uncomment the
- # following line to enable it.
- #
- #enabled: true
-
- # The period after which an account is valid after its registration. When
- # renewing the account, its validity period will be extended by this amount
- # of time. This parameter is required when using the account validity
- # feature.
- #
- #period: 6w
-
- # The amount of time before an account's expiry date at which Synapse will
- # send an email to the account's email address with a renewal link. By
- # default, no such emails are sent.
- #
- # If you enable this setting, you will also need to fill out the 'email' and
- # 'public_baseurl' configuration sections.
- #
- #renew_at: 1w
-
- # The subject of the email sent out with the renewal link. '%(app)s' can be
- # used as a placeholder for the 'app_name' parameter from the 'email'
- # section.
- #
- # Note that the placeholder must be written '%(app)s', including the
- # trailing 's'.
- #
- # If this is not set, a default value is used.
- #
- #renew_email_subject: "Renew your %(app)s account"
-
- # Directory in which Synapse will try to find templates for the HTML files to
- # serve to the user when trying to renew an account. If not set, default
- # templates from within the Synapse package will be used.
- #
- # The currently available templates are:
- #
- # * account_renewed.html: Displayed to the user after they have successfully
- # renewed their account.
- #
- # * account_previously_renewed.html: Displayed to the user if they attempt to
- # renew their account with a token that is valid, but that has already
- # been used. In this case the account is not renewed again.
- #
- # * invalid_token.html: Displayed to the user when they try to renew an account
- # with an unknown or invalid renewal token.
- #
- # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
- # default template contents.
- #
- # The file name of some of these templates can be configured below for legacy
- # reasons.
- #
- #template_dir: "res/templates"
-
- # A custom file name for the 'account_renewed.html' template.
- #
- # If not set, the file is assumed to be named "account_renewed.html".
- #
- #account_renewed_html_path: "account_renewed.html"
-
- # A custom file name for the 'invalid_token.html' template.
- #
- # If not set, the file is assumed to be named "invalid_token.html".
- #
- #invalid_token_html_path: "invalid_token.html"
- """
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index a39d457c56..1ebea88db2 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -64,7 +64,7 @@ def load_appservices(hostname, config_files):
for config_file in config_files:
try:
- with open(config_file, "r") as f:
+ with open(config_file) as f:
appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
if appservice.id in seen_ids:
raise ConfigError(
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index f502ff539e..a3fae02420 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config):
self.third_party_event_rules = load_module(
provider, ("third_party_event_rules",)
)
-
- def generate_config_section(self, **kwargs):
- return """\
- # Server admins can define a Python module that implements extra rules for
- # allowing or denying incoming events. In order to work, this module needs to
- # override the methods defined in synapse/events/third_party_rules.py.
- #
- # This feature is designed to be used in closed federations only, where each
- # participating server enforces the same rules.
- #
- #third_party_event_rules:
- # module: "my_custom_project.SuperRulesSet"
- # config:
- # example_option: 'things'
- """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index fed05ac7be..5679f05e42 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -66,10 +66,8 @@ class TlsConfig(Config):
if self.federation_client_minimum_tls_version == "1.3":
if getattr(SSL, "OP_NO_TLSv1_3", None) is None:
raise ConfigError(
- (
- "federation_client_minimum_tls_version cannot be 1.3, "
- "your OpenSSL does not support it"
- )
+ "federation_client_minimum_tls_version cannot be 1.3, "
+ "your OpenSSL does not support it"
)
# Whitelist of domains to not verify certificates for
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 65dc7a4ed0..0298af4c02 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -291,6 +291,20 @@ class EventBase(metaclass=abc.ABCMeta):
return pdu_json
+ def get_templated_pdu_json(self) -> JsonDict:
+ """
+ Return a JSON object suitable for a templated event, as used in the
+ make_{join,leave,knock} workflow.
+ """
+ # By using _dict directly we don't pull in signatures/unsigned.
+ template_json = dict(self._dict)
+ # The hashes (similar to the signature) need to be recalculated by the
+ # joining/leaving/knocking server after (potentially) modifying the
+ # event.
+ template_json.pop("hashes")
+
+ return template_json
+
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index f7944fd834..7a6eb3e516 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -11,16 +11,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
-from typing import TYPE_CHECKING, Union
-
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.server import HomeServer
+logger = logging.getLogger(__name__)
+
+
+CHECK_EVENT_ALLOWED_CALLBACK = Callable[
+ [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
+]
+ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
+CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
+ [str, str, StateMap[EventBase]], Awaitable[bool]
+]
+CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
+ [str, StateMap[EventBase], str], Awaitable[bool]
+]
+
+
+def load_legacy_third_party_event_rules(hs: "HomeServer"):
+ """Wrapper that loads a third party event rules module configured using the old
+ configuration, and registers the hooks they implement.
+ """
+ if hs.config.third_party_event_rules is None:
+ return
+
+ module, config = hs.config.third_party_event_rules
+
+ api = hs.get_module_api()
+ third_party_rules = module(config=config, module_api=api)
+
+ # The known hooks. If a module implements a method which name appears in this set,
+ # we'll want to register it.
+ third_party_event_rules_methods = {
+ "check_event_allowed",
+ "on_create_room",
+ "check_threepid_can_be_invited",
+ "check_visibility_can_be_modified",
+ }
+
+ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+ # f might be None if the callback isn't implemented by the module. In this
+ # case we don't want to register a callback at all so we return None.
+ if f is None:
+ return None
+
+ # We return a separate wrapper for these methods because, in order to wrap them
+ # correctly, we need to await its result. Therefore it doesn't make a lot of
+ # sense to make it go through the run() wrapper.
+ if f.__name__ == "check_event_allowed":
+
+ # We need to wrap check_event_allowed because its old form would return either
+ # a boolean or a dict, but now we want to return the dict separately from the
+ # boolean.
+ async def wrap_check_event_allowed(
+ event: EventBase,
+ state_events: StateMap[EventBase],
+ ) -> Tuple[bool, Optional[dict]]:
+ # We've already made sure f is not None above, but mypy doesn't do well
+ # across function boundaries so we need to tell it f is definitely not
+ # None.
+ assert f is not None
+
+ res = await f(event, state_events)
+ if isinstance(res, dict):
+ return True, res
+ else:
+ return res, None
+
+ return wrap_check_event_allowed
+
+ if f.__name__ == "on_create_room":
+
+ # We need to wrap on_create_room because its old form would return a boolean
+ # if the room creation is denied, but now we just want it to raise an
+ # exception.
+ async def wrap_on_create_room(
+ requester: Requester, config: dict, is_requester_admin: bool
+ ) -> None:
+ # We've already made sure f is not None above, but mypy doesn't do well
+ # across function boundaries so we need to tell it f is definitely not
+ # None.
+ assert f is not None
+
+ res = await f(requester, config, is_requester_admin)
+ if res is False:
+ raise SynapseError(
+ 403,
+ "Room creation forbidden with these parameters",
+ )
+
+ return wrap_on_create_room
+
+ def run(*args, **kwargs):
+ # mypy doesn't do well across function boundaries so we need to tell it
+ # f is definitely not None.
+ assert f is not None
+
+ return maybe_awaitable(f(*args, **kwargs))
+
+ return run
+
+ # Register the hooks through the module API.
+ hooks = {
+ hook: async_wrapper(getattr(third_party_rules, hook, None))
+ for hook in third_party_event_rules_methods
+ }
+
+ api.register_third_party_rules_callbacks(**hooks)
+
class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra
@@ -35,36 +143,65 @@ class ThirdPartyEventRules:
self.store = hs.get_datastore()
- module = None
- config = None
- if hs.config.third_party_event_rules:
- module, config = hs.config.third_party_event_rules
+ self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
+ self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
+ self._check_threepid_can_be_invited_callbacks: List[
+ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+ ] = []
+ self._check_visibility_can_be_modified_callbacks: List[
+ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+ ] = []
+
+ def register_third_party_rules_callbacks(
+ self,
+ check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+ on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
+ check_threepid_can_be_invited: Optional[
+ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+ ] = None,
+ check_visibility_can_be_modified: Optional[
+ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+ ] = None,
+ ):
+ """Register callbacks from modules for each hook."""
+ if check_event_allowed is not None:
+ self._check_event_allowed_callbacks.append(check_event_allowed)
+
+ if on_create_room is not None:
+ self._on_create_room_callbacks.append(on_create_room)
+
+ if check_threepid_can_be_invited is not None:
+ self._check_threepid_can_be_invited_callbacks.append(
+ check_threepid_can_be_invited,
+ )
- if module is not None:
- self.third_party_rules = module(
- config=config,
- module_api=hs.get_module_api(),
+ if check_visibility_can_be_modified is not None:
+ self._check_visibility_can_be_modified_callbacks.append(
+ check_visibility_can_be_modified,
)
async def check_event_allowed(
self, event: EventBase, context: EventContext
- ) -> Union[bool, dict]:
+ ) -> Tuple[bool, Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
The module can return:
* True: the event is allowed.
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
- * a dict: replacement event data.
+
+ If the event is allowed, the module can also return a dictionary to use as a
+ replacement for the event.
Args:
event: The event to be checked.
context: The context of the event.
Returns:
- The result from the ThirdPartyRules module, as above
+ The result from the ThirdPartyRules module, as above.
"""
- if self.third_party_rules is None:
- return True
+ # Bail out early without hitting the store if we don't have any callbacks to run.
+ if len(self._check_event_allowed_callbacks) == 0:
+ return True, None
prev_state_ids = await context.get_prev_state_ids()
@@ -77,29 +214,46 @@ class ThirdPartyEventRules:
# the hashes and signatures.
event.freeze()
- return await self.third_party_rules.check_event_allowed(event, state_events)
+ for callback in self._check_event_allowed_callbacks:
+ try:
+ res, replacement_data = await callback(event, state_events)
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+ continue
+
+ # Return if the event shouldn't be allowed or if the module came up with a
+ # replacement dict for the event.
+ if res is False:
+ return res, None
+ elif isinstance(replacement_data, dict):
+ return True, replacement_data
+
+ return True, None
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
- ) -> bool:
- """Intercept requests to create room to allow, deny or update the
- request config.
+ ) -> None:
+ """Intercept requests to create room to maybe deny it (via an exception) or
+ update the request config.
Args:
requester
config: The creation config from the client.
is_requester_admin: If the requester is an admin
-
- Returns:
- Whether room creation is allowed or denied.
"""
-
- if self.third_party_rules is None:
- return True
-
- return await self.third_party_rules.on_create_room(
- requester, config, is_requester_admin
- )
+ for callback in self._on_create_room_callbacks:
+ try:
+ await callback(requester, config, is_requester_admin)
+ except Exception as e:
+ # Don't silence the errors raised by this callback since we expect it to
+ # raise an exception to deny the creation of the room; instead make sure
+ # it's a SynapseError we can send to clients.
+ if not isinstance(e, SynapseError):
+ e = SynapseError(
+ 403, "Room creation forbidden with these parameters"
+ )
+
+ raise e
async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
@@ -114,15 +268,20 @@ class ThirdPartyEventRules:
Returns:
True if the 3PID can be invited, False if not.
"""
-
- if self.third_party_rules is None:
+ # Bail out early without hitting the store if we don't have any callbacks to run.
+ if len(self._check_threepid_can_be_invited_callbacks) == 0:
return True
state_events = await self._get_state_map_for_room(room_id)
- return await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
+ for callback in self._check_threepid_can_be_invited_callbacks:
+ try:
+ if await callback(medium, address, state_events) is False:
+ return False
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+ return True
async def check_visibility_can_be_modified(
self, room_id: str, new_visibility: str
@@ -137,18 +296,20 @@ class ThirdPartyEventRules:
Returns:
True if the room's visibility can be modified, False if not.
"""
- if self.third_party_rules is None:
- return True
-
- check_func = getattr(
- self.third_party_rules, "check_visibility_can_be_modified", None
- )
- if not check_func or not callable(check_func):
+ # Bail out early without hitting the store if we don't have any callback
+ if len(self._check_visibility_can_be_modified_callbacks) == 0:
return True
state_events = await self._get_state_map_for_room(room_id)
- return await check_func(room_id, state_events, new_visibility)
+ for callback in self._check_visibility_can_be_modified_callbacks:
+ try:
+ if await callback(room_id, state_events, new_visibility) is False:
+ return False
+ except Exception as e:
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+ return True
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
"""Given a room ID, return the state events of that room.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d91f0ff32f..29619aeeb8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -562,8 +562,7 @@ class FederationServer(FederationBase):
raise IncompatibleRoomVersionError(room_version=room_version)
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
- time_now = self._clock.time_msec()
- return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+ return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
async def on_invite_request(
self, origin: str, content: JsonDict, room_version_id: str
@@ -611,8 +610,7 @@ class FederationServer(FederationBase):
room_version = await self.store.get_room_version_id(room_id)
- time_now = self._clock.time_msec()
- return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+ return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
async def on_send_leave_request(
self, origin: str, content: JsonDict, room_id: str
@@ -659,9 +657,8 @@ class FederationServer(FederationBase):
)
pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
- time_now = self._clock.time_msec()
return {
- "event": pdu.get_pdu_json(time_now),
+ "event": pdu.get_templated_pdu_json(),
"room_version": room_version.identifier,
}
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d800e16912..525f3d39b1 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -38,10 +38,10 @@ class BaseHandler:
"""
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastore() # type: synapse.storage.DataStore
+ self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
- self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
+ self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
@@ -55,12 +55,12 @@ class BaseHandler:
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
- self.admin_redaction_ratelimiter = Ratelimiter(
+ self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
- ) # type: Optional[Ratelimiter]
+ )
else:
self.admin_redaction_ratelimiter = None
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d752cf34f0..078accd634 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,9 +15,11 @@
import email.mime.multipart
import email.utils
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
-from synapse.api.errors import StoreError, SynapseError
+from twisted.web.http import Request
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
@@ -27,6 +29,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+# Types for callbacks to be registered via the module api
+IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
+ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
+# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
+# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
+ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
+ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
+ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
+
class AccountValidityHandler:
def __init__(self, hs: "HomeServer"):
@@ -70,6 +81,99 @@ class AccountValidityHandler:
if hs.config.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
+ self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
+ self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
+ self._on_legacy_send_mail_callback: Optional[
+ ON_LEGACY_SEND_MAIL_CALLBACK
+ ] = None
+ self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
+
+ # The legacy admin requests callback isn't a protected attribute because we need
+ # to access it from the admin servlet, which is outside of this handler.
+ self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
+
+ def register_account_validity_callbacks(
+ self,
+ is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+ on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+ on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+ on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+ on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+ ):
+ """Register callbacks from module for each hook."""
+ if is_user_expired is not None:
+ self._is_user_expired_callbacks.append(is_user_expired)
+
+ if on_user_registration is not None:
+ self._on_user_registration_callbacks.append(on_user_registration)
+
+ # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
+ # an admin one). As part of moving the feature into a module, we need to change
+ # the path from /_matrix/client/unstable/account_validity/... to
+ # /_synapse/client/account_validity, because:
+ #
+ # * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
+ # * the way we register servlets means that modules can't register resources
+ # under /_matrix/client
+ #
+ # We need to allow for a transition period between the old and new endpoints
+ # in order to allow for clients to update (and for emails to be processed).
+ #
+ # Once the email-account-validity module is loaded, it will take control of account
+ # validity by moving the rows from our `account_validity` table into its own table.
+ #
+ # Therefore, we need to allow modules (in practice just the one implementing the
+ # email-based account validity) to temporarily hook into the legacy endpoints so we
+ # can route the traffic coming into the old endpoints into the module, which is
+ # why we have the following three temporary hooks.
+ if on_legacy_send_mail is not None:
+ if self._on_legacy_send_mail_callback is not None:
+ raise RuntimeError("Tried to register on_legacy_send_mail twice")
+
+ self._on_legacy_send_mail_callback = on_legacy_send_mail
+
+ if on_legacy_renew is not None:
+ if self._on_legacy_renew_callback is not None:
+ raise RuntimeError("Tried to register on_legacy_renew twice")
+
+ self._on_legacy_renew_callback = on_legacy_renew
+
+ if on_legacy_admin_request is not None:
+ if self.on_legacy_admin_request_callback is not None:
+ raise RuntimeError("Tried to register on_legacy_admin_request twice")
+
+ self.on_legacy_admin_request_callback = on_legacy_admin_request
+
+ async def is_user_expired(self, user_id: str) -> bool:
+ """Checks if a user has expired against third-party modules.
+
+ Args:
+ user_id: The user to check the expiry of.
+
+ Returns:
+ Whether the user has expired.
+ """
+ for callback in self._is_user_expired_callbacks:
+ expired = await callback(user_id)
+ if expired is not None:
+ return expired
+
+ if self._account_validity_enabled:
+ # If no module could determine whether the user has expired and the legacy
+ # configuration is enabled, fall back to it.
+ return await self.store.is_account_expired(user_id, self.clock.time_msec())
+
+ return False
+
+ async def on_user_registration(self, user_id: str):
+ """Tell third-party modules about a user's registration.
+
+ Args:
+ user_id: The ID of the newly registered user.
+ """
+ for callback in self._on_user_registration_callbacks:
+ await callback(user_id)
+
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
@@ -95,6 +199,17 @@ class AccountValidityHandler:
Raises:
SynapseError if the user is not set to renew.
"""
+ # If a module supports sending a renewal email from here, do that, otherwise do
+ # the legacy dance.
+ if self._on_legacy_send_mail_callback is not None:
+ await self._on_legacy_send_mail_callback(user_id)
+ return
+
+ if not self._account_validity_renew_by_email_enabled:
+ raise AuthError(
+ 403, "Account renewal via email is disabled on this server."
+ )
+
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
# If this user isn't set to be expired, raise an error.
@@ -209,6 +324,10 @@ class AccountValidityHandler:
token is considered stale. A token is stale if the 'token_used_ts_ms' db column
is non-null.
+ This method exists to support handling the legacy account validity /renew
+ endpoint. If a module implements the on_legacy_renew callback, then this process
+ is delegated to the module instead.
+
Args:
renewal_token: Token sent with the renewal request.
Returns:
@@ -218,6 +337,11 @@ class AccountValidityHandler:
* An int representing the user's expiry timestamp as milliseconds since the
epoch, or 0 if the token was invalid.
"""
+ # If a module supports triggering a renew from here, do that, otherwise do the
+ # legacy dance.
+ if self._on_legacy_renew_callback is not None:
+ return await self._on_legacy_renew_callback(renewal_token)
+
try:
(
user_id,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index d75a8b15c3..bfa7f2c545 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -139,7 +139,7 @@ class AdminHandler(BaseHandler):
to_key = RoomStreamToken(None, stream_ordering)
# Events that we've processed in this room
- written_events = set() # type: Set[str]
+ written_events: Set[str] = set()
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -152,7 +152,7 @@ class AdminHandler(BaseHandler):
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
- unseen_to_child_events = {} # type: Dict[str, Set[str]]
+ unseen_to_child_events: Dict[str, Set[str]] = {}
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 862638cc4f..21a17cd2e8 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -96,7 +96,7 @@ class ApplicationServicesHandler:
self.current_max, limit
)
- events_by_room = {} # type: Dict[str, List[EventBase]]
+ events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@@ -275,7 +275,7 @@ class ApplicationServicesHandler:
async def _handle_presence(
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
- events = [] # type: List[JsonDict]
+ events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
@@ -375,7 +375,7 @@ class ApplicationServicesHandler:
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
- protocols = {} # type: Dict[str, List[JsonDict]]
+ protocols: Dict[str, List[JsonDict]] = {}
# Collect up all the individual protocol responses out of the ASes
for s in services:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index e2ac595a62..22a8552241 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -191,7 +191,7 @@ class AuthHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
+ self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
@@ -296,7 +296,7 @@ class AuthHandler(BaseHandler):
# A mapping of user ID to extra attributes to include in the login
# response.
- self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes]
+ self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
async def validate_user_via_ui_auth(
self,
@@ -500,7 +500,7 @@ class AuthHandler(BaseHandler):
all the stages in any of the permitted flows.
"""
- sid = None # type: Optional[str]
+ sid: Optional[str] = None
authdict = clientdict.pop("auth", {})
if "session" in authdict:
sid = authdict["session"]
@@ -588,9 +588,9 @@ class AuthHandler(BaseHandler):
)
# check auth type currently being presented
- errordict = {} # type: Dict[str, Any]
+ errordict: Dict[str, Any] = {}
if "type" in authdict:
- login_type = authdict["type"] # type: str
+ login_type: str = authdict["type"]
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
- params = {} # type: Dict[str, Any]
+ params: Dict[str, Any] = {}
for f in public_flows:
for stage in f:
@@ -1530,9 +1530,9 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- user_id_to_verify = await self.get_session_data(
+ user_id_to_verify: str = await self.get_session_data(
session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
- ) # type: str
+ )
idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
user_id_to_verify
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 7346ccfe93..0325f86e20 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -40,7 +40,7 @@ class CasError(Exception):
def __str__(self):
if self.error_description:
- return "{}: {}".format(self.error, self.error_description)
+ return f"{self.error}: {self.error_description}"
return self.error
@@ -171,7 +171,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
- attributes = {} # type: Dict[str, List[Optional[str]]]
+ attributes: Dict[str, List[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 95bdc5902a..46ee834407 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -452,7 +452,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id
)
- hosts = set() # type: Set[str]
+ hosts: Set[str] = set()
if self.hs.is_mine_id(user_id):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
@@ -613,20 +613,20 @@ class DeviceListUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = (
- {}
- ) # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
+ self._pending_updates: Dict[
+ str, List[Tuple[str, str, Iterable[str], JsonDict]]
+ ] = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
- self._seen_updates = ExpiringCache(
+ self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
- ) # type: ExpiringCache[str, Set[str]]
+ )
# Attempt to resync out of sync device lists every 30s.
self._resync_retry_in_progress = False
@@ -755,7 +755,7 @@ class DeviceListUpdater:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
- seen_updates = self._seen_updates.get(user_id, set()) # type: Set[str]
+ seen_updates: Set[str] = self._seen_updates.get(user_id, set())
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 580b941595..679b47f081 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -203,7 +203,7 @@ class DeviceMessageHandler:
log_kv({"number_of_to_device_messages": len(messages)})
set_tag("sender", sender_user_id)
local_messages = {}
- remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+ remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 06d7012bac..d487fee627 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -237,9 +237,9 @@ class DirectoryHandler(BaseHandler):
async def get_association(self, room_alias: RoomAlias) -> JsonDict:
room_id = None
if self.hs.is_mine(room_alias):
- result = await self.get_association_from_room_alias(
- room_alias
- ) # type: Optional[RoomAliasMapping]
+ result: Optional[
+ RoomAliasMapping
+ ] = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 3972849d4d..d92370859f 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -115,9 +115,9 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
- device_keys_query = query_body.get(
+ device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
- ) # type: Dict[str, Iterable[str]]
+ )
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -136,7 +136,7 @@ class E2eKeysHandler:
# First get local devices.
# A map of destination -> failure response.
- failures = {} # type: Dict[str, JsonDict]
+ failures: Dict[str, JsonDict] = {}
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@@ -151,11 +151,9 @@ class E2eKeysHandler:
# Now attempt to get any remote devices from our local cache.
# A map of destination -> user ID -> device IDs.
- remote_queries_not_in_cache = (
- {}
- ) # type: Dict[str, Dict[str, Iterable[str]]]
+ remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
- query_list = [] # type: List[Tuple[str, Optional[str]]]
+ query_list: List[Tuple[str, Optional[str]]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
@@ -362,9 +360,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
- local_query = [] # type: List[Tuple[str, Optional[str]]]
+ local_query: List[Tuple[str, Optional[str]]] = []
- result_dict = {} # type: Dict[str, Dict[str, dict]]
+ result_dict: Dict[str, Dict[str, dict]] = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@@ -402,9 +400,9 @@ class E2eKeysHandler:
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
"""Handle a device key query from a federated server"""
- device_keys_query = query_body.get(
+ device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
- ) # type: Dict[str, Optional[List[str]]]
+ )
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@@ -421,8 +419,8 @@ class E2eKeysHandler:
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
- local_query = [] # type: List[Tuple[str, str, str]]
- remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
+ local_query: List[Tuple[str, str, str]] = []
+ remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
@@ -439,8 +437,8 @@ class E2eKeysHandler:
results = await self.store.claim_e2e_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key.
- json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
- failures = {} # type: Dict[str, JsonDict]
+ json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ failures: Dict[str, JsonDict] = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
@@ -768,8 +766,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
- signature_list = [] # type: List[SignatureListItem]
- failures = {} # type: Dict[str, Dict[str, JsonDict]]
+ signature_list: List["SignatureListItem"] = []
+ failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
@@ -930,8 +928,8 @@ class E2eKeysHandler:
Raises:
SynapseError: if the input is malformed
"""
- signature_list = [] # type: List[SignatureListItem]
- failures = {} # type: Dict[str, Dict[str, JsonDict]]
+ signature_list: List["SignatureListItem"] = []
+ failures: Dict[str, Dict[str, JsonDict]] = {}
if not signatures:
return signature_list, failures
@@ -1300,7 +1298,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
- self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
+ self._pending_updates: Dict[str, List[Tuple[JsonDict, JsonDict]]] = {}
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
@@ -1349,7 +1347,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
- device_ids = [] # type: List[str]
+ device_ids: List[str] = []
logger.info("pending updates: %r", pending_updates)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index f134f1e234..4b3f037072 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -93,7 +93,7 @@ class EventStreamHandler(BaseHandler):
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
- to_add = [] # type: List[JsonDict]
+ to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
continue
@@ -103,9 +103,9 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = await self.store.get_users_in_room(
+ users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
- ) # type: Iterable[str]
+ )
else:
users = [event.state_key]
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0209aee186..5728719909 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -181,7 +181,7 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
- self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
+ self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._room_backfill = Linearizer("room_backfill")
@@ -368,7 +368,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(ours.values()) # type: List[StateMap[str]]
+ state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
@@ -735,7 +735,7 @@ class FederationHandler(BaseHandler):
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
- (await self.store.get_events(missing_desired_events, allow_rejected=True))
+ await self.store.get_events(missing_desired_events, allow_rejected=True)
)
# check for events which were in the wrong room.
@@ -845,7 +845,7 @@ class FederationHandler(BaseHandler):
# exact key to expect. Otherwise check it matches any key we
# have for that device.
- current_keys = [] # type: Container[str]
+ current_keys: Container[str] = []
if device:
keys = device.get("keys", {}).get("keys", {})
@@ -1185,7 +1185,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
- joined_domains = {} # type: Dict[str, int]
+ joined_domains: Dict[str, int] = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@@ -1314,7 +1314,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version(room_id)
- event_map = {} # type: Dict[str, EventBase]
+ event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
with nested_logging_context(event_id):
@@ -1596,7 +1596,7 @@ class FederationHandler(BaseHandler):
# Ask the remote server to create a valid knock event for us. Once received,
# we sign the event
- params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]]
+ params: Dict[str, Iterable[str]] = {"ver": supported_room_versions}
origin, event, event_format_version = await self._make_and_verify_event(
target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
)
@@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler):
builder=builder
)
- event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler):
# for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK:
- event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -2453,14 +2453,14 @@ class FederationHandler(BaseHandler):
state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
- state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
+ state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
- current_state_ids = {
+ current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
- } # type: StateMap[str]
+ }
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2817,7 +2817,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
- event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
+ event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
@@ -3156,7 +3156,7 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
- last_exception = None # type: Optional[Exception]
+ last_exception: Optional[Exception] = None
# for each public key in the 3pid invite event
for public_key_object in event_auth.get_public_keys(invite_event):
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 157f2ff218..1a6c5c64a2 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -214,7 +214,7 @@ class GroupsLocalWorkerHandler:
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
- destinations = {} # type: Dict[str, Set[str]]
+ destinations: Dict[str, Set[str]] = {}
local_users = set()
for user_id in user_ids:
@@ -227,7 +227,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
- failed_results = [] # type: List[str]
+ failed_results: List[str] = []
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 33d16fbf9c..0961dec5ab 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -302,7 +302,7 @@ class IdentityHandler(BaseHandler):
)
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
- url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
+ url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
content = {
"mxid": mxid,
@@ -695,7 +695,7 @@ class IdentityHandler(BaseHandler):
return data["mxid"]
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- except IOError as e:
+ except OSError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 76242865ae..5d49640760 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -46,9 +46,17 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = ResponseCache(
- hs.get_clock(), "initial_sync_cache"
- ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
+ self.snapshot_cache: ResponseCache[
+ Tuple[
+ str,
+ Optional[StreamToken],
+ Optional[StreamToken],
+ str,
+ Optional[int],
+ bool,
+ bool,
+ ]
+ ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5ecac0732c..cf0359556a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,7 +81,7 @@ class MessageHandler:
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
- self._scheduled_expiry = None # type: Optional[IDelayedCall]
+ self._scheduled_expiry: Optional[IDelayedCall] = None
if not hs.config.worker_app:
run_as_background_process(
@@ -196,9 +196,7 @@ class MessageHandler:
room_state_events = await self.state_store.get_state_for_events(
[event.event_id], state_filter=state_filter
)
- room_state = room_state_events[
- event.event_id
- ] # type: Mapping[Any, EventBase]
+ room_state: Mapping[Any, EventBase] = room_state_events[event.event_id]
else:
raise AuthError(
403,
@@ -421,9 +419,9 @@ class EventCreationHandler:
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
- self.third_party_event_rules = (
+ self.third_party_event_rules: "ThirdPartyEventRules" = (
self.hs.get_third_party_event_rules()
- ) # type: ThirdPartyEventRules
+ )
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@@ -440,7 +438,7 @@ class EventCreationHandler:
#
# map from room id to time-of-last-attempt.
#
- self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
+ self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {}
# The number of forward extremeities before a dummy event is sent.
self._dummy_events_threshold = hs.config.dummy_events_threshold
@@ -465,9 +463,7 @@ class EventCreationHandler:
# Stores the state groups we've recently added to the joined hosts
# external cache. Note that the timeout must be significantly less than
# the TTL on the external cache.
- self._external_cache_joined_hosts_updates = (
- None
- ) # type: Optional[ExpiringCache]
+ self._external_cache_joined_hosts_updates: Optional[ExpiringCache] = None
if self._external_cache.is_enabled():
self._external_cache_joined_hosts_updates = ExpiringCache(
"_external_cache_joined_hosts_updates",
@@ -953,10 +949,10 @@ class EventCreationHandler:
if requester:
context.app_service = requester.app_service
- third_party_result = await self.third_party_event_rules.check_event_allowed(
+ res, new_content = await self.third_party_event_rules.check_event_allowed(
event, context
)
- if not third_party_result:
+ if res is False:
logger.info(
"Event %s forbidden by third-party rules",
event,
@@ -964,11 +960,11 @@ class EventCreationHandler:
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- elif isinstance(third_party_result, dict):
+ elif new_content is not None:
# the third-party rules want to replace the event. We'll need to build a new
# event.
event, context = await self._rebuild_event_after_third_party_rules(
- third_party_result, event
+ new_content, event
)
self.validator.validate_new(event, self.config)
@@ -1299,7 +1295,7 @@ class EventCreationHandler:
# Validate a newly added alias or newly added alt_aliases.
original_alias = None
- original_alt_aliases = [] # type: List[str]
+ original_alt_aliases: List[str] = []
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ee6e41c0e4..eca8f16040 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -72,26 +72,26 @@ _SESSION_COOKIES = [
(b"oidc_session_no_samesite", b"HttpOnly"),
]
+
#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
#: OpenID.Core sec 3.1.3.3.
-Token = TypedDict(
- "Token",
- {
- "access_token": str,
- "token_type": str,
- "id_token": Optional[str],
- "refresh_token": Optional[str],
- "expires_in": int,
- "scope": Optional[str],
- },
-)
+class Token(TypedDict):
+ access_token: str
+ token_type: str
+ id_token: Optional[str]
+ refresh_token: Optional[str]
+ expires_in: int
+ scope: Optional[str]
+
#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
#: there is no real point of doing this in our case.
JWK = Dict[str, str]
+
#: A JWK Set, as per RFC7517 sec 5.
-JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+class JWKS(TypedDict):
+ keys: List[JWK]
class OidcHandler:
@@ -105,9 +105,9 @@ class OidcHandler:
assert provider_confs
self._token_generator = OidcSessionTokenGenerator(hs)
- self._providers = {
+ self._providers: Dict[str, "OidcProvider"] = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
- } # type: Dict[str, OidcProvider]
+ }
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
@@ -178,7 +178,7 @@ class OidcHandler:
# are two.
for cookie_name, _ in _SESSION_COOKIES:
- session = request.getCookie(cookie_name) # type: Optional[bytes]
+ session: Optional[bytes] = request.getCookie(cookie_name)
if session is not None:
break
else:
@@ -255,7 +255,7 @@ class OidcError(Exception):
def __str__(self):
if self.error_description:
- return "{}: {}".format(self.error, self.error_description)
+ return f"{self.error}: {self.error_description}"
return self.error
@@ -277,7 +277,7 @@ class OidcProvider:
self._token_generator = token_generator
self._config = provider
- self._callback_url = hs.config.oidc_callback_url # type: str
+ self._callback_url: str = hs.config.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
@@ -290,7 +290,7 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
- client_secret = None # type: Union[None, str, JwtClientSecret]
+ client_secret: Optional[Union[str, JwtClientSecret]] = None
if provider.client_secret:
client_secret = provider.client_secret
elif provider.client_secret_jwt_key:
@@ -305,7 +305,7 @@ class OidcProvider:
provider.client_id,
client_secret,
provider.client_auth_method,
- ) # type: ClientAuth
+ )
self._client_auth_method = provider.client_auth_method
# cache of metadata for the identity provider (endpoint uris, mostly). This is
@@ -324,7 +324,7 @@ class OidcProvider:
self._allow_existing_users = provider.allow_existing_users
self._http_client = hs.get_proxied_http_client()
- self._server_name = hs.config.server_name # type: str
+ self._server_name: str = hs.config.server_name
# identifier for the external_ids table
self.idp_id = provider.idp_id
@@ -639,7 +639,7 @@ class OidcProvider:
)
logger.warning(description)
# Body was still valid JSON. Might be useful to log it for debugging.
- logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+ logger.warning("Code exchange response: %r", resp)
raise OidcError("server_error", description)
return resp
@@ -1217,10 +1217,12 @@ class OidcSessionData:
ui_auth_session_id = attr.ib(type=str)
-UserAttributeDict = TypedDict(
- "UserAttributeDict",
- {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
-)
+class UserAttributeDict(TypedDict):
+ localpart: Optional[str]
+ display_name: Optional[str]
+ emails: List[str]
+
+
C = TypeVar("C")
@@ -1381,7 +1383,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- emails = [] # type: List[str]
+ emails: List[str] = []
email = render_template_field(self._config.email_template)
if email:
emails.append(email)
@@ -1391,7 +1393,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
- extras = {} # type: Dict[str, str]
+ extras: Dict[str, str] = {}
for key, template in self._config.extra_attributes.items():
try:
extras[key] = template.render(user=userinfo).strip()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1e1186c29e..1dbafd253d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -81,9 +81,9 @@ class PaginationHandler:
self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
- self._purges_in_progress_by_room = set() # type: Set[str]
+ self._purges_in_progress_by_room: Set[str] = set()
# map from purge id to PurgeStatus
- self._purges_by_id = {} # type: Dict[str, PurgeStatus]
+ self._purges_by_id: Dict[str, PurgeStatus] = {}
self._event_serializer = hs.get_event_client_serializer()
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 44ed7a0712..016c5df2ca 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -378,14 +378,14 @@ class WorkerPresenceHandler(BasePresenceHandler):
# The number of ongoing syncs on this process, by user id.
# Empty if _presence_enabled is false.
- self._user_to_num_current_syncs = {} # type: Dict[str, int]
+ self._user_to_num_current_syncs: Dict[str, int] = {}
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
# user_id -> last_sync_ms. Lists the users that have stopped syncing but
# we haven't notified the presence writer of that yet
- self.users_going_offline = {} # type: Dict[str, int]
+ self.users_going_offline: Dict[str, int] = {}
self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -650,7 +650,7 @@ class PresenceHandler(BasePresenceHandler):
# Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted
- self.unpersisted_users_changes = set() # type: Set[str]
+ self.unpersisted_users_changes: Set[str] = set()
hs.get_reactor().addSystemEventTrigger(
"before",
@@ -664,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
- self.user_to_num_current_syncs = {} # type: Dict[str, int]
+ self.user_to_num_current_syncs: Dict[str, int] = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
@@ -674,8 +674,8 @@ class PresenceHandler(BasePresenceHandler):
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
- self.external_process_to_current_syncs = {} # type: Dict[str, Set[str]]
- self.external_process_last_updated_ms = {} # type: Dict[str, int]
+ self.external_process_to_current_syncs: Dict[str, Set[str]] = {}
+ self.external_process_last_updated_ms: Dict[str, int] = {}
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
@@ -1581,9 +1581,7 @@ class PresenceEventSource:
# The set of users that we're interested in and that have had a presence update.
# We'll actually pull the presence updates for these users at the end.
- interested_and_updated_users = (
- set()
- ) # type: Union[Set[str], FrozenSet[str]]
+ interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
if from_key:
# First get all users that have had a presence update
@@ -1950,8 +1948,8 @@ async def get_interested_parties(
A 2-tuple of `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
- room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
- users_to_states = {} # type: Dict[str, List[UserPresenceState]]
+ room_ids_to_states: Dict[str, List[UserPresenceState]] = {}
+ users_to_states: Dict[str, List[UserPresenceState]] = {}
for state in states:
room_ids = await store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
@@ -2063,12 +2061,12 @@ class PresenceFederationQueue:
# stream_id, destinations, user_ids)`. We don't store the full states
# for efficiency, and remote workers will already have the full states
# cached.
- self._queue = [] # type: List[Tuple[int, int, Collection[str], Set[str]]]
+ self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
self._next_id = 1
# Map from instance name to current token
- self._current_tokens = {} # type: Dict[str, int]
+ self._current_tokens: Dict[str, int] = {}
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
@@ -2168,7 +2166,7 @@ class PresenceFederationQueue:
# handle the case where `from_token` stream ID has already been dropped.
start_idx = max(from_token + 1 - self._next_id, -len(self._queue))
- to_send = [] # type: List[Tuple[int, Tuple[str, str]]]
+ to_send: List[Tuple[int, Tuple[str, str]]] = []
limited = False
new_id = upto_token
for _, stream_id, destinations, user_ids in self._queue[start_idx:]:
@@ -2216,7 +2214,7 @@ class PresenceFederationQueue:
if not self._federation:
return
- hosts_to_users = {} # type: Dict[str, Set[str]]
+ hosts_to_users: Dict[str, Set[str]] = {}
for row in rows:
hosts_to_users.setdefault(row.destination, set()).add(row.user_id)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 05b4a97b59..20a033d0ba 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -197,7 +197,7 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
- displayname_to_set = new_displayname # type: Optional[str]
+ displayname_to_set: Optional[str] = new_displayname
if new_displayname == "":
displayname_to_set = None
@@ -286,7 +286,7 @@ class ProfileHandler(BaseHandler):
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- avatar_url_to_set = new_avatar_url # type: Optional[str]
+ avatar_url_to_set: Optional[str] = new_avatar_url
if new_avatar_url == "":
avatar_url_to_set = None
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 0059ad0f56..283483fc2c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -98,8 +98,8 @@ class ReceiptsHandler(BaseHandler):
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
- min_batch_id = None # type: Optional[int]
- max_batch_id = None # type: Optional[int]
+ min_batch_id: Optional[int] = None
+ max_batch_id: Optional[int] = None
for receipt in receipts:
res = await self.store.insert_receipt(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 26ef016179..8cf614136e 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -55,15 +55,12 @@ login_counter = Counter(
["guest", "auth_provider"],
)
-LoginDict = TypedDict(
- "LoginDict",
- {
- "device_id": str,
- "access_token": str,
- "valid_until_ms": Optional[int],
- "refresh_token": Optional[str],
- },
-)
+
+class LoginDict(TypedDict):
+ device_id: str
+ access_token: str
+ valid_until_ms: Optional[int]
+ refresh_token: Optional[str]
class RegistrationHandler(BaseHandler):
@@ -77,6 +74,7 @@ class RegistrationHandler(BaseHandler):
self.identity_handler = self.hs.get_identity_handler()
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
+ self._account_validity_handler = hs.get_account_validity_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self._server_name = hs.hostname
@@ -700,6 +698,10 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
+ # Only call the account validity module(s) on the main process, to avoid
+ # repeating e.g. database writes on all of the workers.
+ await self._account_validity_handler.on_user_registration(user_id)
+
async def register_device(
self,
user_id: str,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 579b1b93c5..370561e549 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -87,7 +87,7 @@ class RoomCreationHandler(BaseHandler):
self.config = hs.config
# Room state based off defined presets
- self._presets_dict = {
+ self._presets_dict: Dict[str, Dict[str, Any]] = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": HistoryVisibility.SHARED,
@@ -109,7 +109,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
- } # type: Dict[str, Dict[str, Any]]
+ }
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -127,9 +127,9 @@ class RoomCreationHandler(BaseHandler):
# If a user tries to update the same room multiple times in quick
# succession, only process the first attempt and return its result to
# subsequent requests
- self._upgrade_response_cache = ResponseCache(
+ self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
- ) # type: ResponseCache[Tuple[str, str]]
+ )
self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -377,10 +377,10 @@ class RoomCreationHandler(BaseHandler):
if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
- creation_content = {
+ creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
- } # type: JsonDict
+ }
# Check if old room was non-federatable
@@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
else:
is_requester_admin = await self.auth.is_server_admin(requester.user)
- # Check whether the third party rules allows/changes the room create
- # request.
- event_allowed = await self.third_party_event_rules.on_create_room(
+ # Let the third party rules modify the room creation config if needed, or abort
+ # the room creation entirely with an exception.
+ await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
- if not event_allowed:
- raise SynapseError(
- 403, "You are not permitted to create rooms", Codes.FORBIDDEN
- )
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
@@ -936,7 +932,7 @@ class RoomCreationHandler(BaseHandler):
etype=EventTypes.PowerLevels, content=pl_content
)
else:
- power_level_content = {
+ power_level_content: JsonDict = {
"users": {creator_id: 100},
"users_default": 0,
"events": {
@@ -955,7 +951,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
- } # type: JsonDict
+ }
if config["original_invitees_have_ops"]:
for invitee in invite_list:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 1c2af01abb..b18557da34 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -48,12 +48,12 @@ class RoomListHandler(BaseHandler):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
- self.response_cache = ResponseCache(
- hs.get_clock(), "room_list"
- ) # type: ResponseCache[Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]]
- self.remote_response_cache = ResponseCache(
- hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
- ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
+ self.response_cache: ResponseCache[
+ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]
+ ] = ResponseCache(hs.get_clock(), "room_list")
+ self.remote_response_cache: ResponseCache[
+ Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
+ ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
async def get_local_public_room_list(
self,
@@ -140,10 +140,10 @@ class RoomListHandler(BaseHandler):
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
- bounds = (
+ bounds: Optional[Tuple[int, str]] = (
batch_token.last_joined_members,
batch_token.last_room_id,
- ) # type: Optional[Tuple[int, str]]
+ )
forwards = batch_token.direction_is_forward
has_batch_token = True
else:
@@ -183,7 +183,7 @@ class RoomListHandler(BaseHandler):
results = [build_room_entry(r) for r in results]
- response = {} # type: JsonDict
+ response: JsonDict = {}
num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
@@ -384,7 +384,11 @@ class RoomListHandler(BaseHandler):
):
logger.debug("Falling back to locally-filtered /publicRooms")
else:
- raise # Not an error that should trigger a fallback.
+ # Not an error that should trigger a fallback.
+ raise SynapseError(502, "Failed to fetch room list")
+ except RequestSendFailed:
+ # Not an error that should trigger a fallback.
+ raise SynapseError(502, "Failed to fetch room list")
# if we reach this point, then we fall back to the situation where
# we currently don't support searching across federation, so we have
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 80ba65b9e0..e6e71e9729 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -83,7 +83,7 @@ class SamlHandler(BaseHandler):
self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
- self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
+ self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@@ -372,7 +372,7 @@ class SamlHandler(BaseHandler):
DOT_REPLACE_PATTERN = re.compile(
- ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+ "[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)
)
@@ -386,10 +386,10 @@ def dot_replace_for_mxid(username: str) -> str:
return username
-MXID_MAPPER_MAP = {
+MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
-} # type: Dict[str, Callable[[str], str]]
+}
@attr.s
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 4e718d3f63..8226d6f5a1 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -192,7 +192,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
- historical_room_ids = [] # type: List[str]
+ historical_room_ids: List[str] = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -216,9 +216,9 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
# Holds result of grouping by room, if applicable
- room_groups = {} # type: Dict[str, JsonDict]
+ room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
- sender_group = {} # type: Dict[str, JsonDict]
+ sender_group: Dict[str, JsonDict] = {}
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -262,7 +262,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
- room_events = [] # type: List[EventBase]
+ room_events: List[EventBase] = []
i = 0
pagination_token = batch_token
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 366e6211e5..5f7d4602bd 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -90,14 +90,14 @@ class SpaceSummaryHandler:
room_queue = deque((_RoomQueueEntry(room_id, ()),))
# rooms we have already processed
- processed_rooms = set() # type: Set[str]
+ processed_rooms: Set[str] = set()
# events we have already processed. We don't necessarily have their event ids,
# so instead we key on (room id, state key)
- processed_events = set() # type: Set[Tuple[str, str]]
+ processed_events: Set[Tuple[str, str]] = set()
- rooms_result = [] # type: List[JsonDict]
- events_result = [] # type: List[JsonDict]
+ rooms_result: List[JsonDict] = []
+ events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS:
queue_entry = room_queue.popleft()
@@ -272,10 +272,10 @@ class SpaceSummaryHandler:
# the set of rooms that we should not walk further. Initialise it with the
# excluded-rooms list; we will add other rooms as we process them so that
# we do not loop.
- processed_rooms = set(exclude_rooms) # type: Set[str]
+ processed_rooms: Set[str] = set(exclude_rooms)
- rooms_result = [] # type: List[JsonDict]
- events_result = [] # type: List[JsonDict]
+ rooms_result: List[JsonDict] = []
+ events_result: List[JsonDict] = []
while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft()
@@ -353,7 +353,7 @@ class SpaceSummaryHandler:
max_children = MAX_ROOMS_PER_SPACE
now = self._clock.time_msec()
- events_result = [] # type: List[JsonDict]
+ events_result: List[JsonDict] = []
for edge_event in itertools.islice(child_events, max_children):
events_result.append(
await self._event_serializer.serialize_event(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 0b297e54c4..1b855a685c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -202,10 +202,10 @@ class SsoHandler:
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
# a map from session id to session data
- self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
+ self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {}
# map from idp_id to SsoIdentityProvider
- self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
+ self._identity_providers: Dict[str, SsoIdentityProvider] = {}
self._consent_at_registration = hs.config.consent.user_consent_at_registration
@@ -296,7 +296,7 @@ class SsoHandler:
)
# if the client chose an IdP, use that
- idp = None # type: Optional[SsoIdentityProvider]
+ idp: Optional[SsoIdentityProvider] = None
if idp_id:
idp = self._identity_providers.get(idp_id)
if not idp:
@@ -669,9 +669,9 @@ class SsoHandler:
remote_user_id,
)
- user_id_to_verify = await self._auth_handler.get_session_data(
+ user_id_to_verify: str = await self._auth_handler.get_session_data(
ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
- ) # type: str
+ )
if not user_id:
logger.warning(
@@ -793,7 +793,7 @@ class SsoHandler:
session.use_display_name = use_display_name
emails_from_idp = set(session.emails)
- filtered_emails = set() # type: Set[str]
+ filtered_emails: Set[str] = set()
# we iterate through the list rather than just building a set conjunction, so
# that we can log attempts to use unknown addresses
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 814d08efcb..3fd89af2a4 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -49,7 +49,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
- self.pos = None # type: Optional[int]
+ self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -131,10 +131,10 @@ class StatsHandler:
mapping from room/user ID to changes in the various fields.
"""
- room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
- user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
+ room_to_stats_deltas: Dict[str, CounterType[str]] = {}
+ user_to_stats_deltas: Dict[str, CounterType[str]] = {}
- room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
+ room_to_state_updates: Dict[str, Dict[str, Any]] = {}
for delta in deltas:
typ = delta["type"]
@@ -164,7 +164,7 @@ class StatsHandler:
)
continue
- event_content = {} # type: JsonDict
+ event_content: JsonDict = {}
if event_id is not None:
event = await self.store.get_event(event_id, allow_none=True)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 151484e21e..8a1b79bd95 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -279,12 +279,14 @@ class SyncHandler:
self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
- self.lazy_loaded_members_cache = ExpiringCache(
+ self.lazy_loaded_members_cache: ExpiringCache[
+ Tuple[str, Optional[str]], LruCache[str, str]
+ ] = ExpiringCache(
"lazy_loaded_members_cache",
self.clock,
max_len=0,
expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
- ) # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
+ )
async def wait_for_sync_for_user(
self,
@@ -441,7 +443,7 @@ class SyncHandler:
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
- ephemeral_by_room = {} # type: JsonDict
+ ephemeral_by_room: JsonDict = {}
for event in typing:
# we want to exclude the room_id from the event, but modifying the
@@ -503,7 +505,7 @@ class SyncHandler:
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
- current_state_ids = frozenset() # type: FrozenSet[str]
+ current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
current_state_ids_map = await self.store.get_current_state_ids(
room_id
@@ -784,9 +786,9 @@ class SyncHandler:
def get_lazy_loaded_members_cache(
self, cache_key: Tuple[str, Optional[str]]
) -> LruCache[str, str]:
- cache = self.lazy_loaded_members_cache.get(
+ cache: Optional[LruCache[str, str]] = self.lazy_loaded_members_cache.get(
cache_key
- ) # type: Optional[LruCache[str, str]]
+ )
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
@@ -985,7 +987,7 @@ class SyncHandler:
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
- state = {} # type: Dict[str, EventBase]
+ state: Dict[str, EventBase] = {}
if state_ids:
state = await self.store.get_events(list(state_ids.values()))
@@ -1089,8 +1091,8 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
- one_time_key_counts = {} # type: JsonDict
- unused_fallback_key_types = [] # type: List[str]
+ one_time_key_counts: JsonDict = {}
+ unused_fallback_key_types: List[str] = []
if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
@@ -1438,7 +1440,7 @@ class SyncHandler:
)
if block_all_room_ephemeral:
- ephemeral_by_room = {} # type: Dict[str, List[JsonDict]]
+ ephemeral_by_room: Dict[str, List[JsonDict]] = {}
else:
now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder,
@@ -1469,7 +1471,7 @@ class SyncHandler:
# If there is ignored users account data and it matches the proper type,
# then use it.
- ignored_users = frozenset() # type: FrozenSet[str]
+ ignored_users: FrozenSet[str] = frozenset()
if ignored_account_data:
ignored_users_data = ignored_account_data.get("ignored_users", {})
if isinstance(ignored_users_data, dict):
@@ -1587,7 +1589,7 @@ class SyncHandler:
user_id, since_token.room_key, now_token.room_key
)
- mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]]
+ mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@@ -1600,7 +1602,7 @@ class SyncHandler:
logger.debug(
"Membership changes in %s: [%s]",
room_id,
- ", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
+ ", ".join("%s (%s)" % (e.event_id, e.membership) for e in events),
)
non_joins = [e for e in events if e.membership != Membership.JOIN]
@@ -1723,7 +1725,7 @@ class SyncHandler:
# This is all screaming out for a refactor, as the logic here is
# subtle and the moving parts numerous.
if leave_event.internal_metadata.is_out_of_band_membership():
- batch_events = [leave_event] # type: Optional[List[EventBase]]
+ batch_events: Optional[List[EventBase]] = [leave_event]
else:
batch_events = None
@@ -1972,7 +1974,7 @@ class SyncHandler:
room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
- summary = {} # type: Optional[JsonDict]
+ summary: Optional[JsonDict] = {}
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
@@ -1996,7 +1998,7 @@ class SyncHandler:
)
if room_builder.rtype == "joined":
- unread_notifications = {} # type: Dict[str, int]
+ unread_notifications: Dict[str, int] = {}
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c0a8364755..0cb651a400 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,11 +68,11 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
- self._room_serials = {} # type: Dict[str, int]
+ self._room_serials: Dict[str, int] = {}
# map room IDs to sets of users currently typing
- self._room_typing = {} # type: Dict[str, Set[str]]
+ self._room_typing: Dict[str, Set[str]] = {}
- self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
+ self._member_last_federation_poke: Dict[RoomMember, int] = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
@@ -217,7 +217,7 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
# clock time we expect to stop
- self._member_typing_until = {} # type: Dict[RoomMember, int]
+ self._member_typing_until: Dict[RoomMember, int] = {}
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
@@ -405,9 +405,9 @@ class TypingWriterHandler(FollowerTypingHandler):
if last_id == current_id:
return [], current_id, False
- changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
- last_id
- ) # type: Optional[Iterable[str]]
+ changed_rooms: Optional[
+ Iterable[str]
+ ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
if changed_rooms is None:
changed_rooms = self._room_serials
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index dacc4f3076..6edb1da50a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -52,7 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.search_all_users = hs.config.user_directory_search_all_users
self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
- self.pos = None # type: Optional[int]
+ self.pos: Optional[int] = None
# Guard to ensure we only process deltas one at a time
self._is_processing = False
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7a6a1717de..f7193e60bd 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -172,7 +172,7 @@ class ProxyAgent(_AgentBase):
"""
uri = uri.strip()
if not _VALID_URI.match(uri):
- raise ValueError("Invalid URI {!r}".format(uri))
+ raise ValueError(f"Invalid URI {uri!r}")
parsed_uri = URI.fromBytes(uri)
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 3b0a38124e..190084e8aa 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -384,7 +384,7 @@ class SynapseRequest(Request):
# authenticated (e.g. and admin is puppetting a user) then we log both.
requester, authenticated_entity = self.get_authenticated_entity()
if authenticated_entity:
- requester = "{}.{}".format(authenticated_entity, requester)
+ requester = f"{authenticated_entity}.{requester}"
self.site.access_logger.log(
log_level,
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 185844f188..ecd51f1b4a 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -374,7 +374,7 @@ def init_tracer(hs: "HomeServer"):
config = JaegerConfig(
config=hs.config.jaeger_config,
- service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()),
+ service_name=f"{hs.config.server_name} {hs.get_instance_name()}",
scope_manager=LogContextScopeManager(hs.config),
metrics_factory=PrometheusMetricsFactory(),
)
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 7e49d0d02c..bb9bcb5592 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -34,7 +34,7 @@ from twisted.web.resource import Resource
from synapse.util import caches
-CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
+CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
INF = float("inf")
@@ -55,8 +55,8 @@ def floatToGoString(d):
# Go switches to exponents sooner than Python.
# We only need to care about positive values for le/quantile.
if d > 0 and dot > 6:
- mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.")
- return "{0}e+0{1}".format(mantissa, dot - 1)
+ mantissa = f"{s[0]}.{s[1:dot]}{s[dot + 1 :]}".rstrip("0.")
+ return f"{mantissa}e+0{dot - 1}"
return s
@@ -65,7 +65,7 @@ def sample_line(line, name):
labelstr = "{{{0}}}".format(
",".join(
[
- '{0}="{1}"'.format(
+ '{}="{}"'.format(
k,
v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
)
@@ -78,10 +78,8 @@ def sample_line(line, name):
timestamp = ""
if line.timestamp is not None:
# Convert to milliseconds.
- timestamp = " {0:d}".format(int(float(line.timestamp) * 1000))
- return "{0}{1} {2}{3}\n".format(
- name, labelstr, floatToGoString(line.value), timestamp
- )
+ timestamp = f" {int(float(line.timestamp) * 1000):d}"
+ return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
def generate_latest(registry, emit_help=False):
@@ -118,12 +116,12 @@ def generate_latest(registry, emit_help=False):
# Output in the old format for compatibility.
if emit_help:
output.append(
- "# HELP {0} {1}\n".format(
+ "# HELP {} {}\n".format(
mname,
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
)
)
- output.append("# TYPE {0} {1}\n".format(mname, mtype))
+ output.append(f"# TYPE {mname} {mtype}\n")
om_samples: Dict[str, List[str]] = {}
for s in metric.samples:
@@ -143,13 +141,13 @@ def generate_latest(registry, emit_help=False):
for suffix, lines in sorted(om_samples.items()):
if emit_help:
output.append(
- "# HELP {0}{1} {2}\n".format(
+ "# HELP {}{} {}\n".format(
metric.name,
suffix,
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
)
)
- output.append("# TYPE {0}{1} gauge\n".format(metric.name, suffix))
+ output.append(f"# TYPE {metric.name}{suffix} gauge\n")
output.extend(lines)
# Get rid of the weird colon things while we're at it
@@ -163,12 +161,12 @@ def generate_latest(registry, emit_help=False):
# Also output in the new format, if it's different.
if emit_help:
output.append(
- "# HELP {0} {1}\n".format(
+ "# HELP {} {}\n".format(
mnewname,
metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
)
)
- output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
+ output.append(f"# TYPE {mnewname} {mtype}\n")
for s in metric.samples:
# Get rid of the OpenMetrics specific samples (we should already have
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 4455fa71a8..3a14260752 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -137,8 +137,7 @@ class _Collector:
_background_process_db_txn_duration,
_background_process_db_sched_duration,
):
- for r in m.collect():
- yield r
+ yield from m.collect()
REGISTRY.register(_Collector())
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 308f045700..1259fc2d90 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -12,18 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import email.utils
import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+)
+
+import jinja2
from twisted.internet import defer
from twisted.web.resource import IResource
from synapse.events import EventBase
from synapse.http.client import SimpleHttpClient
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ respond_with_html,
+)
+from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.util import Clock
+from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi
are loaded into Synapse.
"""
-__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
+__all__ = [
+ "errors",
+ "make_deferred_yieldable",
+ "parse_json_object_from_request",
+ "respond_with_html",
+ "run_in_background",
+ "cached",
+ "UserID",
+ "DatabasePool",
+ "LoggingTransaction",
+ "DirectServeHtmlResource",
+ "DirectServeJsonResource",
+ "ModuleApi",
+]
logger = logging.getLogger(__name__)
@@ -52,12 +89,28 @@ class ModuleApi:
self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"]
self._state = hs.get_state_handler()
+ self._clock: Clock = hs.get_clock()
+ self._send_email_handler = hs.get_send_email_handler()
+
+ try:
+ app_name = self._hs.config.email_app_name
+
+ self._from_string = self._hs.config.email_notif_from % {"app": app_name}
+ except (KeyError, TypeError):
+ # If substitution failed (which can happen if the string contains
+ # placeholders other than just "app", or if the type of the placeholder is
+ # not a string), fall back to the bare strings.
+ self._from_string = self._hs.config.email_notif_from
+
+ self._raw_from = email.utils.parseaddr(self._from_string)[1]
# We expose these as properties below in order to attach a helpful docstring.
self._http_client: SimpleHttpClient = hs.get_simple_http_client()
self._public_room_list_manager = PublicRoomListManager(hs)
self._spam_checker = hs.get_spam_checker()
+ self._account_validity_handler = hs.get_account_validity_handler()
+ self._third_party_event_rules = hs.get_third_party_event_rules()
#################################################################################
# The following methods should only be called during the module's initialisation.
@@ -67,6 +120,16 @@ class ModuleApi:
"""Registers callbacks for spam checking capabilities."""
return self._spam_checker.register_callbacks
+ @property
+ def register_account_validity_callbacks(self):
+ """Registers callbacks for account validity capabilities."""
+ return self._account_validity_handler.register_account_validity_callbacks
+
+ @property
+ def register_third_party_rules_callbacks(self):
+ """Registers callbacks for third party event rules capabilities."""
+ return self._third_party_event_rules.register_third_party_rules_callbacks
+
def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path.
@@ -101,22 +164,56 @@ class ModuleApi:
"""
return self._public_room_list_manager
- def get_user_by_req(self, req, allow_guest=False):
+ @property
+ def public_baseurl(self) -> str:
+ """The configured public base URL for this homeserver."""
+ return self._hs.config.public_baseurl
+
+ @property
+ def email_app_name(self) -> str:
+ """The application name configured in the homeserver's configuration."""
+ return self._hs.config.email.email_app_name
+
+ async def get_user_by_req(
+ self,
+ req: SynapseRequest,
+ allow_guest: bool = False,
+ allow_expired: bool = False,
+ ) -> Requester:
"""Check the access_token provided for a request
Args:
- req (twisted.web.server.Request): Incoming HTTP request
- allow_guest (bool): True if guest users should be allowed. If this
+ req: Incoming HTTP request
+ allow_guest: True if guest users should be allowed. If this
is False, and the access token is for a guest user, an
AuthError will be thrown
+ allow_expired: True if expired users should be allowed. If this
+ is False, and the access token is for an expired user, an
+ AuthError will be thrown
+
Returns:
- twisted.internet.defer.Deferred[synapse.types.Requester]:
- the requester for this request
+ The requester for this request
+
Raises:
- synapse.api.errors.AuthError: if no user by that token exists,
+ InvalidClientCredentialsError: if no user by that token exists,
or the token is invalid.
"""
- return self._auth.get_user_by_req(req, allow_guest)
+ return await self._auth.get_user_by_req(
+ req,
+ allow_guest,
+ allow_expired=allow_expired,
+ )
+
+ async def is_user_admin(self, user_id: str) -> bool:
+ """Checks if a user is a server admin.
+
+ Args:
+ user_id: The Matrix ID of the user to check.
+
+ Returns:
+ True if the user is a server admin, False otherwise.
+ """
+ return await self._store.is_server_admin(UserID.from_string(user_id))
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
@@ -134,6 +231,32 @@ class ModuleApi:
return username
return UserID(username, self._hs.hostname).to_string()
+ async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
+ """Look up the profile info for the user with the given localpart.
+
+ Args:
+ localpart: The localpart to look up profile information for.
+
+ Returns:
+ The profile information (i.e. display name and avatar URL).
+ """
+ return await self._store.get_profileinfo(localpart)
+
+ async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
+ """Look up the threepids (email addresses and phone numbers) associated with the
+ given Matrix user ID.
+
+ Args:
+ user_id: The Matrix user ID to look up threepids for.
+
+ Returns:
+ A list of threepids, each threepid being represented by a dictionary
+ containing a "medium" key which value is "email" for email addresses and
+ "msisdn" for phone numbers, and an "address" key which value is the
+ threepid's address.
+ """
+ return await self._store.user_get_threepids(user_id)
+
def check_user_exists(self, user_id):
"""Check if user exists.
@@ -464,6 +587,88 @@ class ModuleApi:
presence_events, destination
)
+ def looping_background_call(
+ self,
+ f: Callable,
+ msec: float,
+ *args,
+ desc: Optional[str] = None,
+ **kwargs,
+ ):
+ """Wraps a function as a background process and calls it repeatedly.
+
+ Waits `msec` initially before calling `f` for the first time.
+
+ Args:
+ f: The function to call repeatedly. f can be either synchronous or
+ asynchronous, and must follow Synapse's logcontext rules.
+ More info about logcontexts is available at
+ https://matrix-org.github.io/synapse/latest/log_contexts.html
+ msec: How long to wait between calls in milliseconds.
+ *args: Positional arguments to pass to function.
+ desc: The background task's description. Default to the function's name.
+ **kwargs: Key arguments to pass to function.
+ """
+ if desc is None:
+ desc = f.__name__
+
+ if self._hs.config.run_background_tasks:
+ self._clock.looping_call(
+ run_as_background_process,
+ msec,
+ desc,
+ f,
+ *args,
+ **kwargs,
+ )
+ else:
+ logger.warning(
+ "Not running looping call %s as the configuration forbids it",
+ f,
+ )
+
+ async def send_mail(
+ self,
+ recipient: str,
+ subject: str,
+ html: str,
+ text: str,
+ ):
+ """Send an email on behalf of the homeserver.
+
+ Args:
+ recipient: The email address for the recipient.
+ subject: The email's subject.
+ html: The email's HTML content.
+ text: The email's text content.
+ """
+ await self._send_email_handler.send_email(
+ email_address=recipient,
+ subject=subject,
+ app_name=self.email_app_name,
+ html=html,
+ text=text,
+ )
+
+ def read_templates(
+ self,
+ filenames: List[str],
+ custom_template_directory: Optional[str] = None,
+ ) -> List[jinja2.Template]:
+ """Read and load the content of the template files at the given location.
+ By default, Synapse will look for these templates in its configured template
+ directory, but another directory to search in can be provided.
+
+ Args:
+ filenames: The name of the template files to look for.
+ custom_template_directory: An additional directory to look for the files in.
+
+ Returns:
+ A list containing the loaded templates, with the orders matching the one of
+ the filenames parameter.
+ """
+ return self._hs.config.read_templates(filenames, custom_template_directory)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 02bbb0be39..98ea911a81 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -14,5 +14,9 @@
"""Exception types which are exposed as part of the stable module API"""
-from synapse.api.errors import RedirectException, SynapseError # noqa: F401
+from synapse.api.errors import ( # noqa: F401
+ InvalidClientCredentialsError,
+ RedirectException,
+ SynapseError,
+)
from synapse.config._base import ConfigError # noqa: F401
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 2519ad76db..85621f33ef 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,10 +62,6 @@ class PusherPool:
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self._account_validity_enabled = (
- hs.config.account_validity.account_validity_enabled
- )
-
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
@@ -89,6 +85,8 @@ class PusherPool:
# map from user id to app_id:pushkey to pusher
self.pushers: Dict[str, Dict[str, Pusher]] = {}
+ self._account_validity_handler = hs.get_account_validity_handler()
+
def start(self) -> None:
"""Starts the pushers off in a background process."""
if not self._should_start_pushers:
@@ -238,12 +236,9 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
- if self._account_validity_enabled:
- expired = await self.store.is_account_expired(
- u, self.clock.time_msec()
- )
- if expired:
- continue
+ expired = await self._account_validity_handler.is_user_expired(u)
+ if expired:
+ continue
if u in self.pushers:
for p in self.pushers[u].values():
@@ -268,12 +263,9 @@ class PusherPool:
for u in users_affected:
# Don't push if the user account has expired
- if self._account_validity_enabled:
- expired = await self.store.is_account_expired(
- u, self.clock.time_msec()
- )
- if expired:
- continue
+ expired = await self._account_validity_handler.is_user_expired(u)
+ if expired:
+ continue
if u in self.pushers:
for p in self.pushers[u].values():
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 3c51a742bf..40ee33646c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
# Get the room ID from the identifier.
try:
- remote_room_hosts = [
+ remote_room_hosts: Optional[List[str]] = [
x.decode("ascii") for x in request.args[b"server_name"]
- ] # type: Optional[List[str]]
+ ]
except Exception:
remote_room_hosts = None
room_id, remote_room_hosts = await self.resolve_room_id(
@@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(
- json_decoder.decode(filter_json)
- ) # type: Optional[Filter]
+ event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
else:
event_filter = None
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7d75564758..589e47fa47 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
- self.nonces = {} # type: Dict[str, int]
+ self.nonces: Dict[str, int] = {}
self.hs = hs
def _clear_old_nonces(self):
@@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- body = parse_json_object_from_request(request)
+ if self.account_activity_handler.on_legacy_admin_request_callback:
+ expiration_ts = await (
+ self.account_activity_handler.on_legacy_admin_request_callback(request)
+ )
+ else:
+ body = parse_json_object_from_request(request)
- if "user_id" not in body:
- raise SynapseError(400, "Missing property 'user_id' in the request body")
+ if "user_id" not in body:
+ raise SynapseError(
+ 400,
+ "Missing property 'user_id' in the request body",
+ )
- expiration_ts = await self.account_activity_handler.renew_account_for_user(
- body["user_id"],
- body.get("expiration_ts"),
- not body.get("enable_renewal_emails", True),
- )
+ expiration_ts = await self.account_activity_handler.renew_account_for_user(
+ body["user_id"],
+ body.get("expiration_ts"),
+ not body.get("enable_renewal_emails", True),
+ )
res = {"expiration_ts": expiration_ts}
return 200, res
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index cbcb60fe31..11567bf32c 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -44,19 +44,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-LoginResponse = TypedDict(
- "LoginResponse",
- {
- "user_id": str,
- "access_token": str,
- "home_server": str,
- "expires_in_ms": Optional[int],
- "refresh_token": Optional[str],
- "device_id": str,
- "well_known": Optional[Dict[str, Any]],
- },
- total=False,
-)
+class LoginResponse(TypedDict, total=False):
+ user_id: str
+ access_token: str
+ home_server: str
+ expires_in_ms: Optional[int]
+ refresh_token: Optional[str]
+ device_id: str
+ well_known: Optional[Dict[str, Any]]
class LoginRestServlet(RestServlet):
@@ -121,7 +116,7 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- sso_flow = {
+ sso_flow: JsonDict = {
"type": LoginRestServlet.SSO_TYPE,
"identity_providers": [
_get_auth_flow_dict_for_idp(
@@ -129,7 +124,7 @@ class LoginRestServlet(RestServlet):
)
for idp in self._sso_handler.get_identity_providers().values()
],
- } # type: JsonDict
+ }
if self._msc2858_enabled:
# backwards-compatibility support for clients which don't
@@ -150,9 +145,7 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- flows.extend(
- ({"type": t} for t in self.auth_handler.get_supported_login_types())
- )
+ flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
@@ -447,7 +440,7 @@ def _get_auth_flow_dict_for_idp(
use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API
"""
- e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
@@ -561,7 +554,7 @@ class SsoRedirectServlet(RestServlet):
finish_request(request)
return
- args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
sso_url = await self._sso_handler.handle_redirect_request(
request,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index ebf4e32230..31a1193cd3 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
- limit = int(content.get("limit", 100)) # type: Optional[int]
+ limit: Optional[int] = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(
- json_decoder.decode(filter_json)
- ) # type: Optional[Filter]
+ event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
- event_filter = Filter(
- json_decoder.decode(filter_json)
- ) # type: Optional[Filter]
+ event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
else:
event_filter = None
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 2d1ad3d3fb..3ebe401861 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -14,7 +14,7 @@
import logging
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet
@@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet):
)
async def on_POST(self, request):
- if not self.account_validity_renew_by_email_enabled:
- raise AuthError(
- 403, "Account renewal via email is disabled on this server."
- )
-
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index f8dcee603c..d537d811d8 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"]
)
- response = (200, {}) # type: Tuple[int, dict]
+ response: Tuple[int, dict] = (200, {})
return response
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index e52570cd8e..4282e2b228 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
has_consented = False
public_version = username == ""
if not public_version:
- args = request.args # type: Dict[bytes, List[bytes]]
+ args: Dict[bytes, List[bytes]] = request.args
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
"""
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- args = request.args # type: Dict[bytes, List[bytes]]
+ args: Dict[bytes, List[bytes]] = request.args
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index d56a1ae482..63a40b1852 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
- query = {server.decode("ascii"): {}} # type: dict
+ query: dict = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
# Note that the value is unused.
- cache_misses = {} # type: Dict[str, Dict[str, int]]
+ cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
diff --git a/synapse/rest/media/v1/__init__.py b/synapse/rest/media/v1/__init__.py
index d20186bbd0..3dd16d4bb5 100644
--- a/synapse/rest/media/v1/__init__.py
+++ b/synapse/rest/media/v1/__init__.py
@@ -17,7 +17,7 @@ import PIL.Image
# check for JPEG support.
try:
PIL.Image._getdecoder("rgb", "jpeg", None)
-except IOError as e:
+except OSError as e:
if str(e).startswith("decoder jpeg not available"):
raise Exception(
"FATAL: jpeg codec not supported. Install pillow correctly! "
@@ -32,7 +32,7 @@ except Exception:
# check for PNG support.
try:
PIL.Image._getdecoder("rgb", "zip", None)
-except IOError as e:
+except OSError as e:
if str(e).startswith("decoder zip not available"):
raise Exception(
"FATAL: zip codec not supported. Install pillow correctly! "
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 0fb4cd81f1..90364ebcf7 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
- postpath = request.postpath # type: List[bytes] # type: ignore
+ postpath: List[bytes] = request.postpath # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 21c43c340c..4f702f890c 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -78,16 +78,16 @@ class MediaRepository:
Thumbnailer.set_limits(self.max_image_pixels)
- self.primary_base_path = hs.config.media_store_path # type: str
- self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
+ self.primary_base_path: str = hs.config.media_store_path
+ self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
- self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
- self.recently_accessed_locals = set() # type: Set[str]
+ self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
+ self.recently_accessed_locals: Set[str] = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -711,7 +711,7 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
- thumbnails = {} # type: Dict[Tuple[int, int, str], str]
+ thumbnails: Dict[Tuple[int, int, str], str] = {}
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index c7fd97c46c..56cdc1b4ed 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -191,7 +191,7 @@ class MediaStorage:
for provider in self.storage_providers:
for path in paths:
- res = await provider.fetch(path, file_info) # type: Any
+ res: Any = await provider.fetch(path, file_info)
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
@@ -233,7 +233,7 @@ class MediaStorage:
os.makedirs(dirname)
for provider in self.storage_providers:
- res = await provider.fetch(path, file_info) # type: Any
+ res: Any = await provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0adfb1a70f..8e7fead3a2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
- self._cache = ExpiringCache(
+ self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
- ) # type: ExpiringCache[str, ObservableDeferred]
+ )
if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url # type: Optional[str]
+ url_to_download: Optional[str] = url
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
- og = {} # type: Dict[str, Optional[str]]
+ og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 62dc4aae2d..146adca8f1 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
errcode=Codes.TOO_LARGE,
)
- args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
upload_name_bytes = parse_bytes_from_args(args, "filename")
if upload_name_bytes:
try:
- upload_name = upload_name_bytes.decode("utf8") # type: Optional[str]
+ upload_name: Optional[str] = upload_name_bytes.decode("utf8")
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
@@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-dispostion
try:
- content = request.content # type: IO # type: ignore
+ content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content(
media_type, upload_name, content, content_length, requester.user
)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index 9b002cc15e..ab24ec0a8e 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
use_display_name = parse_boolean(request, "use_display_name", default=False)
try:
- emails_to_use = [
+ emails_to_use: List[str] = [
val.decode("utf-8") for val in request.args.get(b"use_email", [])
- ] # type: List[str]
+ ]
except ValueError:
raise SynapseError(400, "Query parameter use_email must be utf-8")
except SynapseError as e:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index f80d822c12..ccf9ac51ef 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -907,7 +907,7 @@ class DatabasePool:
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(
- *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
)
for k in keys:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 50e7ddd735..c55508867d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -203,9 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_messages_for_device", delete_messages_for_device_txn
)
- log_kv(
- {"message": "deleted {} messages for device".format(count), "count": count}
- )
+ log_kv({"message": f"deleted {count} messages for device", "count": count})
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ec8579b9ad..a396a201d4 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2010,10 +2010,6 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- events_by_room: Dict[str, List[EventBase]] = {}
- for ev in events:
- events_by_room.setdefault(ev.room_id, []).append(ev)
-
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 66ad363bfb..e70d3649ff 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -27,8 +27,11 @@ from synapse.util import json_encoder
_DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = ""
+
# A room in a group.
-_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+class _RoomInGroup(TypedDict):
+ room_id: str
+ is_public: bool
class GroupServerWorkerStore(SQLBaseStore):
@@ -92,6 +95,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_public": False # Whether this is a public room or not
}
"""
+
# TODO: Pagination
def _get_rooms_in_group_txn(txn):
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index e3a544d9b2..dc0bbc56ac 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -316,6 +316,135 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+ async def count_r30v2_users(self) -> Dict[str, int]:
+ """
+ Counts the number of 30 day retained users, defined as users that:
+ - Appear more than once in the past 60 days
+ - Have more than 30 days between the most and least recent appearances that
+ occurred in the past 60 days.
+
+ (This is the second version of this metric, hence R30'v2')
+
+ Returns:
+ A mapping from client type to the number of 30-day retained users for that client.
+
+ The dict keys are:
+ - "all" (a combined number of users across any and all clients)
+ - "android" (Element Android)
+ - "ios" (Element iOS)
+ - "electron" (Element Desktop)
+ - "web" (any web application -- it's not possible to distinguish Element Web here)
+ """
+
+ def _count_r30v2_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
+ one_day_from_now_in_secs = now + 86400
+
+ # This is the 'per-platform' count.
+ sql = """
+ SELECT
+ client_type,
+ count(client_type)
+ FROM
+ (
+ SELECT
+ user_id,
+ CASE
+ WHEN
+ LOWER(user_agent) LIKE '%%riot%%' OR
+ LOWER(user_agent) LIKE '%%element%%'
+ THEN CASE
+ WHEN
+ LOWER(user_agent) LIKE '%%electron%%'
+ THEN 'electron'
+ WHEN
+ LOWER(user_agent) LIKE '%%android%%'
+ THEN 'android'
+ WHEN
+ LOWER(user_agent) LIKE '%%ios%%'
+ THEN 'ios'
+ ELSE 'unknown'
+ END
+ WHEN
+ LOWER(user_agent) LIKE '%%mozilla%%' OR
+ LOWER(user_agent) LIKE '%%gecko%%'
+ THEN 'web'
+ ELSE 'unknown'
+ END as client_type
+ FROM
+ user_daily_visits
+ WHERE
+ timestamp > ?
+ AND
+ timestamp < ?
+ GROUP BY
+ user_id,
+ client_type
+ HAVING
+ max(timestamp) - min(timestamp) > ?
+ ) AS temp
+ GROUP BY
+ client_type
+ ;
+ """
+
+ # We initialise all the client types to zero, so we get an explicit
+ # zero if they don't appear in the query results
+ results = {"ios": 0, "android": 0, "web": 0, "electron": 0}
+ txn.execute(
+ sql,
+ (
+ sixty_days_ago_in_secs * 1000,
+ one_day_from_now_in_secs * 1000,
+ thirty_days_in_secs * 1000,
+ ),
+ )
+
+ for row in txn:
+ if row[0] == "unknown":
+ continue
+ results[row[0]] = row[1]
+
+ # This is the 'all users' count.
+ sql = """
+ SELECT COUNT(*) FROM (
+ SELECT
+ 1
+ FROM
+ user_daily_visits
+ WHERE
+ timestamp > ?
+ AND
+ timestamp < ?
+ GROUP BY
+ user_id
+ HAVING
+ max(timestamp) - min(timestamp) > ?
+ ) AS r30_users
+ """
+
+ txn.execute(
+ sql,
+ (
+ sixty_days_ago_in_secs * 1000,
+ one_day_from_now_in_secs * 1000,
+ thirty_days_in_secs * 1000,
+ ),
+ )
+ row = txn.fetchone()
+ if row is None:
+ results["all"] = 0
+ else:
+ results["all"] = row[0]
+
+ return results
+
+ return await self.db_pool.runInteraction(
+ "count_r30v2_users", _count_r30v2_users
+ )
+
def _get_start_of_day(self):
"""
Returns millisecond unixtime for start of UTC day.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4d82c4c26d..68f1b40ea6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -649,7 +649,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
- users_in_room.update((row for row in event_to_memberships.values() if row))
+ users_in_room.update(row for row in event_to_memberships.values() if row)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 82a7686df0..61392b9639 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -639,7 +639,7 @@ def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
def executescript(txn: Cursor, schema_path: str) -> None:
- with open(schema_path, "r") as f:
+ with open(schema_path) as f:
execute_statements_from_stream(txn, f)
diff --git a/synapse/types.py b/synapse/types.py
index fad23c8700..429bb013d2 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -577,10 +577,10 @@ class RoomStreamToken:
entries = []
for name, pos in self.instance_map.items():
instance_id = await store.get_id_for_instance(name)
- entries.append("{}.{}".format(instance_id, pos))
+ entries.append(f"{instance_id}.{pos}")
encoded_map = "~".join(entries)
- return "m{}~{}".format(self.stream, encoded_map)
+ return f"m{self.stream}~{encoded_map}"
else:
return "s%d" % (self.stream,)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index efeba0cb96..5c65d187b6 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -90,8 +90,7 @@ def enumerate_leaves(node, depth):
yield node
else:
for n in node.values():
- for m in enumerate_leaves(n, depth - 1):
- yield m
+ yield from enumerate_leaves(n, depth - 1)
P = TypeVar("P")
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index a6df81ebff..4138931e7b 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d):
"""
if isinstance(d, TreeCacheNode):
for value_d in d.values():
- for value in iterate_tree_cache_entry(value_d):
- yield value
+ yield from iterate_tree_cache_entry(value_d)
else:
yield d
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index 31b24dd188..d8532411c2 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -31,13 +31,13 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# If pidfile already exists, we should read pid from there; to overwrite it, if
# locking will fail, because locking attempt somehow purges the file contents.
if os.path.isfile(pid_file):
- with open(pid_file, "r") as pid_fh:
+ with open(pid_file) as pid_fh:
old_pid = pid_fh.read()
# Create a lockfile so that only one instance of this daemon is running at any time.
try:
lock_fh = open(pid_file, "w")
- except IOError:
+ except OSError:
print("Unable to create the pidfile.")
sys.exit(1)
@@ -45,7 +45,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# Try to get an exclusive lock on the file. This will fail if another process
# has the file locked.
fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB)
- except IOError:
+ except OSError:
print("Unable to lock on the pidfile.")
# We need to overwrite the pidfile if we got here.
#
@@ -113,7 +113,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
try:
lock_fh.write("%s" % (os.getpid()))
lock_fh.flush()
- except IOError:
+ except OSError:
logger.error("Unable to write pid to the pidfile.")
print("Unable to write pid to the pidfile.")
sys.exit(1)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 1dc6b90275..17532059e9 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -96,7 +96,7 @@ async def filter_events_for_client(
if isinstance(ignored_users_dict, dict):
ignore_list = frozenset(ignored_users_dict.keys())
- erased_senders = await storage.main.are_users_erased((e.sender for e in events))
+ erased_senders = await storage.main.are_users_erased(e.sender for e in events)
if filter_send_to_client:
room_ids = {e.room_id for e in events}
@@ -353,7 +353,7 @@ async def filter_events_for_server(
)
if not check_history_visibility_only:
- erased_senders = await storage.main.are_users_erased((e.sender for e in events))
+ erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index 2da6ba4dde..5527e278db 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -1,9 +1,11 @@
import synapse
+from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.rest.client.v1 import login, room
from tests import unittest
from tests.unittest import HomeserverTestCase
+FIVE_MINUTES_IN_SECONDS = 300
ONE_DAY_IN_SECONDS = 86400
@@ -151,3 +153,243 @@ class PhoneHomeTestCase(HomeserverTestCase):
# *Now* the user appears in R30.
r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
self.assertEqual(r30_results, {"all": 1, "unknown": 1})
+
+
+class PhoneHomeR30V2TestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def _advance_to(self, desired_time_secs: float):
+ now = self.hs.get_clock().time()
+ assert now < desired_time_secs
+ self.reactor.advance(desired_time_secs - now)
+
+ def make_homeserver(self, reactor, clock):
+ hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock)
+
+ # We don't want our tests to actually report statistics, so check
+ # that it's not enabled
+ assert not hs.config.report_stats
+
+ # This starts the needed data collection that we rely on to calculate
+ # R30v2 metrics.
+ start_phone_stats_home(hs)
+ return hs
+
+ def test_r30v2_minimum_usage(self):
+ """
+ Tests the minimum amount of interaction necessary for the R30v2 metric
+ to consider a user 'retained'.
+ """
+
+ # Register a user, log it in, create a room and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token)
+ self.helper.send(room_id, "message", tok=access_token)
+ first_post_at = self.hs.get_clock().time()
+
+ # Give time for user_daily_visits table to be updated.
+ # (user_daily_visits is updated every 5 minutes using a looping call.)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ store = self.hs.get_datastore()
+
+ # Check the R30 results do not count that user.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Advance 31 days.
+ # (R30v2 includes users with **more** than 30 days between the two visits,
+ # and user_daily_visits records the timestamp as the start of the day.)
+ self.reactor.advance(31 * ONE_DAY_IN_SECONDS)
+ # Also advance 5 minutes to let another user_daily_visits update occur
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # (Make sure the user isn't somehow counted by this point.)
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Send a message (this counts as activity)
+ self.helper.send(room_id, "message2", tok=access_token)
+
+ # We have to wait a few minutes for the user_daily_visits table to
+ # be updated by a background process.
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # *Now* the user is counted.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Advance to JUST under 60 days after the user's first post
+ self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS - 5)
+
+ # Check the user is still counted.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Advance into the next day. The user's first activity is now more than 60 days old.
+ self._advance_to(first_post_at + 60 * ONE_DAY_IN_SECONDS + 5)
+
+ # Check the user is now no longer counted in R30.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ def test_r30v2_user_must_be_retained_for_at_least_a_month(self):
+ """
+ Tests that a newly-registered user must be retained for a whole month
+ before appearing in the R30v2 statistic, even if they post every day
+ during that time!
+ """
+
+ # set a custom user-agent to impersonate Element/Android.
+ headers = (
+ (
+ "User-Agent",
+ "Element/1.1 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
+ ),
+ )
+
+ # Register a user and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!", custom_headers=headers)
+ room_id = self.helper.create_room_as(
+ room_creator=user_id, tok=access_token, custom_headers=headers
+ )
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # Give time for user_daily_visits table to be updated.
+ # (user_daily_visits is updated every 5 minutes using a looping call.)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ store = self.hs.get_datastore()
+
+ # Check the user does not contribute to R30 yet.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ for _ in range(30):
+ # This loop posts a message every day for 30 days
+ self.reactor.advance(ONE_DAY_IN_SECONDS - FIVE_MINUTES_IN_SECONDS)
+ self.helper.send(
+ room_id, "I'm still here", tok=access_token, custom_headers=headers
+ )
+
+ # give time for user_daily_visits to update
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # Notice that the user *still* does not contribute to R30!
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # advance yet another day with more activity
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+ self.helper.send(
+ room_id, "Still here!", tok=access_token, custom_headers=headers
+ )
+
+ # give time for user_daily_visits to update
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # *Now* the user appears in R30.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "android": 1, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ def test_r30v2_returning_dormant_users_not_counted(self):
+ """
+ Tests that dormant users (users inactive for a long time) do not
+ contribute to R30v2 when they return for just a single day.
+ This is a key difference between R30 and R30v2.
+ """
+
+ # set a custom user-agent to impersonate Element/iOS.
+ headers = (
+ (
+ "User-Agent",
+ "Riot/1.4 (iPhone; iOS 13; Scale/4.00)",
+ ),
+ )
+
+ # Register a user and send a message
+ user_id = self.register_user("u1", "secret!")
+ access_token = self.login("u1", "secret!", custom_headers=headers)
+ room_id = self.helper.create_room_as(
+ room_creator=user_id, tok=access_token, custom_headers=headers
+ )
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # the user goes inactive for 2 months
+ self.reactor.advance(60 * ONE_DAY_IN_SECONDS)
+
+ # the user returns for one day, perhaps just to check out a new feature
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # Give time for user_daily_visits table to be updated.
+ # (user_daily_visits is updated every 5 minutes using a looping call.)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ store = self.hs.get_datastore()
+
+ # Check that the user does not contribute to R30v2, even though it's been
+ # more than 30 days since registration.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
+
+ # Check that this is a situation where old R30 differs:
+ # old R30 DOES count this as 'retained'.
+ r30_results = self.get_success(store.count_r30_users())
+ self.assertEqual(r30_results, {"all": 1, "ios": 1})
+
+ # Now we want to check that the user will still be able to appear in
+ # R30v2 as long as the user performs some other activity between
+ # 30 and 60 days later.
+ self.reactor.advance(32 * ONE_DAY_IN_SECONDS)
+ self.helper.send(room_id, "message", tok=access_token, custom_headers=headers)
+
+ # (give time for tables to update)
+ self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
+
+ # Check the user now satisfies the requirements to appear in R30v2.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
+ )
+
+ # Advance to 59.5 days after the user's first R30v2-eligible activity.
+ self.reactor.advance(27.5 * ONE_DAY_IN_SECONDS)
+
+ # Check the user still appears in R30v2.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 1, "ios": 1, "android": 0, "electron": 0, "web": 0}
+ )
+
+ # Advance to 60.5 days after the user's first R30v2-eligible activity.
+ self.reactor.advance(ONE_DAY_IN_SECONDS)
+
+ # Check the user no longer appears in R30v2.
+ r30_results = self.get_success(store.count_r30v2_users())
+ self.assertEqual(
+ r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0}
+ )
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index c5e1c5458b..28dd47a28b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -16,17 +16,19 @@ from typing import Dict
from unittest.mock import Mock
from synapse.events import EventBase
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, StateMap
+from synapse.util.frozenutils import unfreeze
from tests import unittest
thread_local = threading.local()
-class ThirdPartyRulesTestModule:
+class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: ModuleApi):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
@@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
return config
-def current_rules_module() -> ThirdPartyRulesTestModule:
- return thread_local.rules_module
+class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ super().__init__(config, module_api)
+
+ def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return False
+
+
+class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ super().__init__(config, module_api)
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ d = event.get_dict()
+ content = unfreeze(event.content)
+ content["foo"] = "bar"
+ d["content"] = content
+ return d
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def default_config(self):
- config = super().default_config()
- config["third_party_event_rules"] = {
- "module": __name__ + ".ThirdPartyRulesTestModule",
- "config": {},
- }
- return config
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ load_legacy_third_party_event_rules(hs)
+
+ return hs
def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ # Some tests might prevent room creation on purpose.
+ try:
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ except Exception:
+ pass
def test_third_party_rules(self):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
@@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
- return ev.type != "foo.bar.forbidden"
+ return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
- current_rules_module().check_event_allowed = callback
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+ callback
+ ]
channel = self.make_request(
"PUT",
@@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
- return True
+ return True, None
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"500", channel.result)
+ # check_event_allowed has some error handling, so it shouldn't 500 just because a
+ # module did something bad.
+ self.assertEqual(channel.code, 200, channel.result)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "x")
def test_modify_event(self):
"""The module can return a modified version of the event"""
@@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
async def check(ev: EventBase, state):
d = ev.get_dict()
d["content"] = {"x": "y"}
- return d
+ return True, d
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"msgtype": "m.text",
"body": d["content"]["body"].upper(),
}
- return d
+ return True, d
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# Send an event, then edit it.
channel = self.make_request(
@@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self):
- """Tests that the module can send an event into a room via the module api"""
+ """Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
"body": "Hello!",
@@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"sender": self.user_id,
}
event: EventBase = self.get_success(
- current_rules_module().module_api.create_and_send_event_into_room(
- event_dict
- )
+ self.hs.get_module_api().create_and_send_event_into_room(event_dict)
)
self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id)
self.assertEquals(event.type, "m.room.message")
self.assertEquals(event.content, content)
+
+ @unittest.override_config(
+ {
+ "third_party_event_rules": {
+ "module": __name__ + ".LegacyChangeEvents",
+ "config": {},
+ }
+ }
+ )
+ def test_legacy_check_event_allowed(self):
+ """Tests that the wrapper for legacy check_event_allowed callbacks works
+ correctly.
+ """
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+
+ self.assertIn("foo", channel.json_body["content"].keys())
+ self.assertEqual(channel.json_body["content"]["foo"], "bar")
+
+ @unittest.override_config(
+ {
+ "third_party_event_rules": {
+ "module": __name__ + ".LegacyDenyNewRooms",
+ "config": {},
+ }
+ }
+ )
+ def test_legacy_on_create_room(self):
+ """Tests that the wrapper for legacy on_create_room callbacks works
+ correctly.
+ """
+ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 69798e95c3..fc2d35596e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -19,7 +19,7 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Mapping, MutableMapping, Optional
+from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import patch
import attr
@@ -53,6 +53,9 @@ class RestHelper:
tok: str = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
) -> str:
"""
Create a room.
@@ -87,6 +90,7 @@ class RestHelper:
"POST",
path,
json.dumps(content).encode("utf8"),
+ custom_headers=custom_headers,
)
assert channel.result["code"] == b"%d" % expect_code, channel.result
@@ -175,14 +179,30 @@ class RestHelper:
self.auth_user_id = temp_id
- def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+ def send(
+ self,
+ room_id,
+ body=None,
+ txn_id=None,
+ tok=None,
+ expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
if body is None:
body = "body_text_here"
content = {"msgtype": "m.text", "body": body}
return self.send_event(
- room_id, "m.room.message", content, txn_id, tok, expect_code
+ room_id,
+ "m.room.message",
+ content,
+ txn_id,
+ tok,
+ expect_code,
+ custom_headers=custom_headers,
)
def send_event(
@@ -193,6 +213,9 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -207,6 +230,7 @@ class RestHelper:
"PUT",
path,
json.dumps(content or {}).encode("utf8"),
+ custom_headers=custom_headers,
)
assert (
diff --git a/tests/test_state.py b/tests/test_state.py
index 780eba823c..e5488df1ac 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -168,6 +168,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
+ "get_account_validity_handler",
"hostname",
]
)
diff --git a/tests/unittest.py b/tests/unittest.py
index c6d9064423..3eec9c4d5b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
user_id = channel.json_body["user_id"]
return user_id
- def login(self, username, password, device_id=None):
+ def login(
+ self,
+ username,
+ password,
+ device_id=None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
body["device_id"] = device_id
channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+ "POST",
+ "/_matrix/client/r0/login",
+ json.dumps(body).encode("utf8"),
+ custom_headers=custom_headers,
)
self.assertEqual(channel.code, 200, channel.result)
|