diff --git a/CHANGES.md b/CHANGES.md
index 58973d2a5d..8bd7825089 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,11 +1,11 @@
-Synapse 1.22.0 (2020-10-30)
+Synapse 1.22.1 (2020-10-30)
===========================
Bugfixes
--------
-- Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0. ([\#8676](https://github.com/matrix-org/synapse/issues/8676))
-- Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules. ([\#8678](https://github.com/matrix-org/synapse/issues/8678))
+- Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0. ([\#8676](https://github.com/matrix-org/synapse/issues/8676))
+- Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules. Broken in v1.22.0. ([\#8678](https://github.com/matrix-org/synapse/issues/8678))
Synapse 1.22.0 (2020-10-27)
diff --git a/changelog.d/8595.misc b/changelog.d/8595.misc
new file mode 100644
index 0000000000..24fab65cda
--- /dev/null
+++ b/changelog.d/8595.misc
@@ -0,0 +1 @@
+Implement and use an @lru_cache decorator.
diff --git a/changelog.d/8635.doc b/changelog.d/8635.doc
new file mode 100644
index 0000000000..00fb1e61a7
--- /dev/null
+++ b/changelog.d/8635.doc
@@ -0,0 +1 @@
+Improve the sample configuration for single sign-on providers.
diff --git a/changelog.d/8676.bugfix b/changelog.d/8676.bugfix
deleted file mode 100644
index df16c72761..0000000000
--- a/changelog.d/8676.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug where an appservice may not be forwarded events for a room it was recently invited to. Broken in v1.22.0.
diff --git a/changelog.d/8678.bugfix b/changelog.d/8678.bugfix
deleted file mode 100644
index 0508d8f109..0000000000
--- a/changelog.d/8678.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules.
diff --git a/changelog.d/8682.bugfix b/changelog.d/8682.bugfix
new file mode 100644
index 0000000000..e61276aa05
--- /dev/null
+++ b/changelog.d/8682.bugfix
@@ -0,0 +1 @@
+Fix an exception raised when handling multiple concurrent requests for remote media while using multiple media repositories.
diff --git a/changelog.d/8688.misc b/changelog.d/8688.misc
new file mode 100644
index 0000000000..bef8dc425a
--- /dev/null
+++ b/changelog.d/8688.misc
@@ -0,0 +1 @@
+Abstract some invite-related code in preparation for landing knocking.
\ No newline at end of file
diff --git a/changelog.d/8690.misc b/changelog.d/8690.misc
new file mode 100644
index 0000000000..0f38ba1f5d
--- /dev/null
+++ b/changelog.d/8690.misc
@@ -0,0 +1 @@
+Fail tests if they do not await coroutines.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index de4ad50458..6cd95707af 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1683,10 +1683,8 @@ trusted_key_servers:
## Single sign-on integration ##
-# Enable SAML2 for registration and login. Uses pysaml2.
-#
-# At least one of `sp_config` or `config_path` must be set in this section to
-# enable SAML login.
+# The following settings can be used to make Synapse use a single sign-on
+# provider for authentication, instead of its internal password database.
#
# You will probably also want to set the following options to `false` to
# disable the regular login/registration flows:
@@ -1695,6 +1693,11 @@ trusted_key_servers:
#
# You will also want to investigate the settings under the "sso" configuration
# section below.
+
+# Enable SAML2 for registration and login. Uses pysaml2.
+#
+# At least one of `sp_config` or `config_path` must be set in this section to
+# enable SAML login.
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
@@ -1710,40 +1713,42 @@ saml2_config:
# so it is not normally necessary to specify them unless you need to
# override them.
#
- #sp_config:
- # # point this to the IdP's metadata. You can use either a local file or
- # # (preferably) a URL.
- # metadata:
- # #local: ["saml2/idp.xml"]
- # remote:
- # - url: https://our_idp/metadata.xml
- #
- # # By default, the user has to go to our login page first. If you'd like
- # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
- # # 'service.sp' section:
- # #
- # #service:
- # # sp:
- # # allow_unsolicited: true
- #
- # # The examples below are just used to generate our metadata xml, and you
- # # may well not need them, depending on your setup. Alternatively you
- # # may need a whole lot more detail - see the pysaml2 docs!
- #
- # description: ["My awesome SP", "en"]
- # name: ["Test SP", "en"]
- #
- # organization:
- # name: Example com
- # display_name:
- # - ["Example co", "en"]
- # url: "http://example.com"
- #
- # contact_person:
- # - given_name: Bob
- # sur_name: "the Sysadmin"
- # email_address": ["admin@example.com"]
- # contact_type": technical
+ sp_config:
+ # Point this to the IdP's metadata. You must provide either a local
+ # file via the `local` attribute or (preferably) a URL via the
+ # `remote` attribute.
+ #
+ #metadata:
+ # local: ["saml2/idp.xml"]
+ # remote:
+ # - url: https://our_idp/metadata.xml
+
+ # By default, the user has to go to our login page first. If you'd like
+ # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
+ # 'service.sp' section:
+ #
+ #service:
+ # sp:
+ # allow_unsolicited: true
+
+ # The examples below are just used to generate our metadata xml, and you
+ # may well not need them, depending on your setup. Alternatively you
+ # may need a whole lot more detail - see the pysaml2 docs!
+
+ #description: ["My awesome SP", "en"]
+ #name: ["Test SP", "en"]
+
+ #organization:
+ # name: Example com
+ # display_name:
+ # - ["Example co", "en"]
+ # url: "http://example.com"
+
+ #contact_person:
+ # - given_name: Bob
+ # sur_name: "the Sysadmin"
+    #    email_address: ["admin@example.com"]
+    #    contact_type: technical
# Instead of putting the config inline as above, you can specify a
# separate pysaml2 configuration file:
@@ -1819,11 +1824,10 @@ saml2_config:
# value: "sales"
-# OpenID Connect integration. The following settings can be used to make Synapse
-# use an OpenID Connect Provider for authentication, instead of its internal
-# password database.
+# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
#
-# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md.
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+# for some example configurations.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
@@ -1956,15 +1960,37 @@ oidc_config:
-# Enable CAS for registration and login.
+# Enable Central Authentication Service (CAS) for registration and login.
#
-#cas_config:
-# enabled: true
-# server_url: "https://cas-server.com"
-# service_url: "https://homeserver.domain.com:8448"
-# #displayname_attribute: name
-# #required_attributes:
-# # name: value
+cas_config:
+ # Uncomment the following to enable authorization against a CAS server.
+ # Defaults to false.
+ #
+ #enabled: true
+
+ # The URL of the CAS authorization endpoint.
+ #
+ #server_url: "https://cas-server.com"
+
+ # The public URL of the homeserver.
+ #
+ #service_url: "https://homeserver.domain.com:8448"
+
+ # The attribute of the CAS response to use as the display name.
+ #
+ # If unset, no displayname will be set.
+ #
+ #displayname_attribute: name
+
+ # It is possible to configure Synapse to only allow logins if CAS attributes
+ # match particular values. All of the keys in the mapping below must exist
+  # and the values must match the given value. Alternatively, if the given value
+  # is None then any value is allowed (the attribute simply must exist).
+ # All of the listed attributes must match for the login to be permitted.
+ #
+ #required_attributes:
+ # userGroup: "staff"
+ # department: None
# Additional settings to use with single-sign on systems such as OpenID Connect,
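To make the `required_attributes` matching semantics described above concrete, here is a minimal hedged sketch in plain Python. The helper name and dict-based API are illustrative only, not the code Synapse itself uses internally:

```python
from typing import Any, Dict, Optional


def attributes_allow_login(
    cas_attributes: Dict[str, Any],
    required_attributes: Dict[str, Optional[str]],
) -> bool:
    """Apply the documented rules: every configured key must exist in the
    CAS response, and a non-None configured value must match exactly
    (None means any value is acceptable)."""
    for key, required in required_attributes.items():
        if key not in cas_attributes:
            return False
        if required is not None and cas_attributes[key] != required:
            return False
    return True


# With the sample configuration above:
assert attributes_allow_login(
    {"userGroup": "staff", "department": "sales"},
    {"userGroup": "staff", "department": None},
)
assert not attributes_allow_login(
    {"department": "sales"}, {"userGroup": "staff", "department": None}
)
```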
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 35a82c0bfe..3e1df2b035 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.22.0"
+__version__ = "1.22.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 4526c1a67b..2f97e6d258 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -26,14 +26,14 @@ class CasConfig(Config):
def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
- if cas_config:
- self.cas_enabled = cas_config.get("enabled", True)
+ self.cas_enabled = cas_config and cas_config.get("enabled", True)
+
+ if self.cas_enabled:
self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
- self.cas_required_attributes = cas_config.get("required_attributes", {})
+ self.cas_required_attributes = cas_config.get("required_attributes") or {}
else:
- self.cas_enabled = False
self.cas_server_url = None
self.cas_service_url = None
self.cas_displayname_attribute = None
@@ -41,13 +41,35 @@ class CasConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Enable CAS for registration and login.
+ # Enable Central Authentication Service (CAS) for registration and login.
#
- #cas_config:
- # enabled: true
- # server_url: "https://cas-server.com"
- # service_url: "https://homeserver.domain.com:8448"
- # #displayname_attribute: name
- # #required_attributes:
- # # name: value
+ cas_config:
+ # Uncomment the following to enable authorization against a CAS server.
+ # Defaults to false.
+ #
+ #enabled: true
+
+ # The URL of the CAS authorization endpoint.
+ #
+ #server_url: "https://cas-server.com"
+
+ # The public URL of the homeserver.
+ #
+ #service_url: "https://homeserver.domain.com:8448"
+
+ # The attribute of the CAS response to use as the display name.
+ #
+ # If unset, no displayname will be set.
+ #
+ #displayname_attribute: name
+
+ # It is possible to configure Synapse to only allow logins if CAS attributes
+ # match particular values. All of the keys in the mapping below must exist
+          # and the values must match the given value. Alternatively, if the given value
+          # is None then any value is allowed (the attribute simply must exist).
+ # All of the listed attributes must match for the login to be permitted.
+ #
+ #required_attributes:
+ # userGroup: "staff"
+ # department: None
"""
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 7597fbc864..69d188341c 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -87,11 +87,10 @@ class OIDCConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
- # OpenID Connect integration. The following settings can be used to make Synapse
- # use an OpenID Connect Provider for authentication, instead of its internal
- # password database.
+ # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
#
- # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md.
+ # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+ # for some example configurations.
#
oidc_config:
# Uncomment the following to enable authorization against an OpenID Connect
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 99aa8b3bf1..778750f43b 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -216,10 +216,8 @@ class SAML2Config(Config):
return """\
## Single sign-on integration ##
- # Enable SAML2 for registration and login. Uses pysaml2.
- #
- # At least one of `sp_config` or `config_path` must be set in this section to
- # enable SAML login.
+ # The following settings can be used to make Synapse use a single sign-on
+ # provider for authentication, instead of its internal password database.
#
# You will probably also want to set the following options to `false` to
# disable the regular login/registration flows:
@@ -228,6 +226,11 @@ class SAML2Config(Config):
#
# You will also want to investigate the settings under the "sso" configuration
# section below.
+
+ # Enable SAML2 for registration and login. Uses pysaml2.
+ #
+ # At least one of `sp_config` or `config_path` must be set in this section to
+ # enable SAML login.
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
@@ -243,40 +246,42 @@ class SAML2Config(Config):
# so it is not normally necessary to specify them unless you need to
# override them.
#
- #sp_config:
- # # point this to the IdP's metadata. You can use either a local file or
- # # (preferably) a URL.
- # metadata:
- # #local: ["saml2/idp.xml"]
- # remote:
- # - url: https://our_idp/metadata.xml
- #
- # # By default, the user has to go to our login page first. If you'd like
- # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
- # # 'service.sp' section:
- # #
- # #service:
- # # sp:
- # # allow_unsolicited: true
- #
- # # The examples below are just used to generate our metadata xml, and you
- # # may well not need them, depending on your setup. Alternatively you
- # # may need a whole lot more detail - see the pysaml2 docs!
- #
- # description: ["My awesome SP", "en"]
- # name: ["Test SP", "en"]
- #
- # organization:
- # name: Example com
- # display_name:
- # - ["Example co", "en"]
- # url: "http://example.com"
- #
- # contact_person:
- # - given_name: Bob
- # sur_name: "the Sysadmin"
- # email_address": ["admin@example.com"]
- # contact_type": technical
+ sp_config:
+ # Point this to the IdP's metadata. You must provide either a local
+ # file via the `local` attribute or (preferably) a URL via the
+ # `remote` attribute.
+ #
+ #metadata:
+ # local: ["saml2/idp.xml"]
+ # remote:
+ # - url: https://our_idp/metadata.xml
+
+ # By default, the user has to go to our login page first. If you'd like
+ # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
+ # 'service.sp' section:
+ #
+ #service:
+ # sp:
+ # allow_unsolicited: true
+
+ # The examples below are just used to generate our metadata xml, and you
+ # may well not need them, depending on your setup. Alternatively you
+ # may need a whole lot more detail - see the pysaml2 docs!
+
+ #description: ["My awesome SP", "en"]
+ #name: ["Test SP", "en"]
+
+ #organization:
+ # name: Example com
+ # display_name:
+ # - ["Example co", "en"]
+ # url: "http://example.com"
+
+ #contact_person:
+ # - given_name: Bob
+ # sur_name: "the Sysadmin"
+            #    email_address: ["admin@example.com"]
+            #    contact_type: technical
# Instead of putting the config inline as above, you can specify a
# separate pysaml2 configuration file:
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9b5478b53..82a72dc34f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,8 @@
# limitations under the License.
import logging
-from collections import namedtuple
+import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = await self._get_rules_for_room(room_id)
+ rules_for_room = self._get_rules_for_room(room_id)
rules_by_user = await rules_for_room.get_rules(event, context)
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
return rules_by_user
- @cached()
+ @lru_cache()
def _get_rules_for_room(self, room_id):
"""Get the current RulesForRoom object for the given room id
@@ -275,12 +276,14 @@ class RulesForRoom:
the entire cache for the room.
"""
- def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+ def __init__(
+ self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ ):
"""
Args:
hs (HomeServer)
room_id (str)
- rules_for_room_cache(Cache): The cache object that caches these
+ rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric)
"""
@@ -489,13 +492,21 @@ class RulesForRoom:
self.state_group = state_group
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+ # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+    # which means that it is stored on the bulk_get_push_rules cache entry. In order
+    # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
+ # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+ # same `cache` and `room_id`.
+ #
+ # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+ # set `frozen=True`.
+
+ cache = attr.ib(type=LruCache)
+ room_id = attr.ib(type=str)
+
def __call__(self):
- rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+ rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
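The reason for swapping the namedtuple for a frozen attrs class is spelled out in the comment above: value-based `__eq__` and `__hash__` let duplicate invalidation callbacks deduplicate. A standalone hedged illustration (using a plain sentinel object in place of the real `LruCache`):

```python
import attr


@attr.attrs(slots=True, frozen=True)
class _Invalidation:
    cache = attr.ib()
    room_id = attr.ib(type=str)


shared_cache = object()  # stand-in for the LruCache instance
a = _Invalidation(shared_cache, "!room:example.com")
b = _Invalidation(shared_cache, "!room:example.com")

# Equal and hash-equal, so registering both against the same cache entry
# in a set-like structure keeps only one callback.
assert a == b and hash(a) == hash(b)
assert len({a, b}) == 1
```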
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 5cce7237a0..9cac74ebd8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -305,15 +305,12 @@ class MediaRepository:
# file_id is the ID we use to track the file locally. If we've already
         # seen the file then reuse the existing ID, otherwise generate a new
# one.
- if media_info:
- file_id = media_info["filesystem_id"]
- else:
- file_id = random_string(24)
-
- file_info = FileInfo(server_name, file_id)
# If we have an entry in the DB, try and look for it
if media_info:
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
@@ -324,14 +321,34 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
- media_info = await self._download_remote_file(server_name, media_id, file_id)
+ try:
+ media_info = await self._download_remote_file(server_name, media_id,)
+ except SynapseError:
+ raise
+ except Exception as e:
+            # An exception may be because another process downloaded the same
+            # media concurrently, so check whether we now have it after all.
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
+ if not media_info:
+ raise e
+
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
+ # We generate thumbnails even if another process downloaded the media
+ # as a) it's conceivable that the other download request dies before it
+ # generates thumbnails, but mainly b) we want to be sure the thumbnails
+ # have finished being generated before responding to the client,
+ # otherwise they'll request thumbnails and get a 404 if they're not
+ # ready yet.
+ await self._generate_thumbnails(
+ server_name, media_id, file_id, media_info["media_type"]
+ )
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(
- self, server_name: str, media_id: str, file_id: str
- ) -> dict:
+ async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -346,6 +363,8 @@ class MediaRepository:
The media info of the file.
"""
+ file_id = random_string(24)
+
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -401,22 +420,32 @@ class MediaRepository:
await finish()
- media_type = headers[b"Content-Type"][0].decode("ascii")
- upload_name = get_filename_from_headers(headers)
- time_now_ms = self.clock.time_msec()
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ upload_name = get_filename_from_headers(headers)
+ time_now_ms = self.clock.time_msec()
+
+ # Multiple remote media download requests can race (when using
+            # multiple media repos), so this may throw a constraint violation
+ # exception. If it does we'll delete the newly downloaded file from
+ # disk (as we're in the ctx manager).
+ #
+ # However: we've already called `finish()` so we may have also
+ # written to the storage providers. This is preferable to the
+ # alternative where we call `finish()` *after* this, where we could
+ # end up having an entry in the DB but fail to write the files to
+ # the storage providers.
+ await self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
logger.info("Stored remote media in file %r", fname)
- await self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
-
media_info = {
"media_type": media_type,
"media_length": length,
@@ -425,8 +454,6 @@ class MediaRepository:
"filesystem_id": file_id,
}
- await self._generate_thumbnails(server_name, media_id, file_id, media_type)
-
return media_info
def _get_thumbnail_requirements(self, media_type):
@@ -692,42 +719,60 @@ class MediaRepository:
if not t_byte_source:
continue
- try:
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
- url_cache=url_cache,
- )
-
- output_path = await self.media_storage.store_file(
- t_byte_source, file_info
- )
- finally:
- t_byte_source.close()
-
- t_len = os.path.getsize(output_path)
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
- # Write to database
- if server_name:
- await self.store.store_remote_media_thumbnail(
- server_name,
- media_id,
- file_id,
- t_width,
- t_height,
- t_type,
- t_method,
- t_len,
- )
- else:
- await self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ await self.media_storage.write_to_file(t_byte_source, f)
+ await finish()
+ finally:
+ t_byte_source.close()
+
+ t_len = os.path.getsize(fname)
+
+ # Write to database
+ if server_name:
+ # Multiple remote media download requests can race (when
+                    # using multiple media repos), so this may throw a constraint
+                    # violation exception. If it does we'll delete the newly
+ # generated thumbnail from disk (as we're in the ctx
+ # manager).
+ #
+ # However: we've already called `finish()` so we may have
+ # also written to the storage providers. This is preferable
+ # to the alternative where we call `finish()` *after* this,
+ # where we could end up having an entry in the DB but fail
+ # to write the files to the storage providers.
+ try:
+ await self.store.store_remote_media_thumbnail(
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
+ )
+ except Exception as e:
+ thumbnail_exists = await self.store.get_remote_media_thumbnail(
+ server_name, media_id, t_width, t_height, t_type,
+ )
+ if not thumbnail_exists:
+ raise e
+ else:
+ await self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
return {"width": m_width, "height": m_height}
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage:
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
+ self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@@ -70,13 +71,16 @@ class MediaStorage:
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- await defer_to_thread(
- self.hs.get_reactor(), _write_file_synchronously, source, f
- )
+ await self.write_to_file(source, f)
await finish_cb()
return fname
+ async def write_to_file(self, source: IO, output: IO):
+ """Asynchronously write the `source` to `output`.
+ """
+ await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
@@ -112,14 +116,20 @@ class MediaStorage:
finished_called = [False]
- async def finish():
- for provider in self.storage_providers:
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
try:
with open(fname, "wb") as f:
+
+ async def finish():
+ # Ensure that all writes have been flushed and close the
+ # file.
+ f.flush()
+ f.close()
+
+ for provider in self.storage_providers:
+ await provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
yield f, fname, finish
except Exception:
try:
@@ -210,7 +220,7 @@ class MediaStorage:
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.hs.get_reactor()
+ open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
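The `media_storage.py` change above moves `finish()` inside the `with open(...)` block so it can flush and close the file before the storage providers read it from disk. A minimal hedged sketch of that context-manager-plus-callback shape (simplified; error handling and the provider API are elided):

```python
import contextlib


@contextlib.contextmanager
def store_into_file(fname: str):
    """Yield a writable file plus a `finish` callback; `finish` flushes
    and closes the file so later consumers see the complete contents."""
    with open(fname, "wb") as f:

        async def finish() -> None:
            f.flush()
            f.close()
            # ...the real code hands `fname` to storage providers here...

        yield f, fname, finish
```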
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5ae263827d..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -530,7 +530,7 @@ class EventsWorkerStore(SQLBaseStore):
self,
context: EventContext,
state_types_to_include: List[EventTypes],
- membership_user_id: Optional[str],
+ membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
Retrieve the stripped state from a room, given an event context to retrieve state
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index daf57675d8..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
+ async def get_remote_media_thumbnail(
+ self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+ ) -> Optional[Dict[str, Any]]:
+ """Fetch the thumbnail info of given width, height and type.
+ """
+
+ return await self.db_pool.simple_select_one(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": origin,
+ "media_id": media_id,
+ "thumbnail_width": t_width,
+ "thumbnail_height": t_height,
+ "thumbnail_type": t_type,
+ },
+ retcols=(
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ "filesystem_id",
+ ),
+ allow_none=True,
+ desc="get_remote_media_thumbnail",
+ )
+
async def store_remote_media_thumbnail(
self,
origin,
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import functools
import inspect
import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
from weakref import WeakValueDictionary
from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
class _CacheDescriptorBase:
- def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+ def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
self.orig = orig
arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
self.add_cache_context = cache_context
+ self.cache_key_builder = get_cache_key_builder(
+ self.arg_names, self.arg_defaults
+ )
+
+
+class _LruCachedFunction(Generic[F]):
+ cache = None # type: LruCache[CacheKey, Any]
+ __call__ = None # type: F
+
+
+def lru_cache(
+ max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+ """A method decorator that applies a memoizing cache around the function.
+
+ This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+ that the signature is slightly different.
+
+ The main differences with functools.lru_cache are:
+ (a) the size of the cache can be controlled via the cache_factor mechanism
+ (b) the wrapped function can request a "cache_context" which provides a
+ callback mechanism to indicate that the result is no longer valid
+ (c) prometheus metrics are exposed automatically.
+
+ The function should take zero or more arguments, which are used as the key for the
+ cache. Single-argument functions use that argument as the cache key; otherwise the
+ arguments are built into a tuple.
+
+ Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when the caches they call are
+ invalidated) by adding a special "cache_context" argument to the function
+ and passing that as a kwarg to all caches called. For example:
+
+ @lru_cache(cache_context=True)
+ def foo(self, key, cache_context):
+ r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+ r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+ return r1 + r2
+
+ The wrapped function also has a 'cache' property which offers direct access to the
+ underlying LruCache.
+ """
+
+ def func(orig: F) -> _LruCachedFunction[F]:
+ desc = LruCacheDescriptor(
+ orig, max_entries=max_entries, cache_context=cache_context,
+ )
+ return cast(_LruCachedFunction[F], desc)
+
+ return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+ """Helper for @lru_cache"""
+
+ class _Sentinel(enum.Enum):
+ sentinel = object()
+
+ def __init__(
+ self, orig, max_entries: int = 1000, cache_context: bool = False,
+ ):
+ super().__init__(orig, num_args=None, cache_context=cache_context)
+ self.max_entries = max_entries
+
+ def __get__(self, obj, owner):
+ cache = LruCache(
+ cache_name=self.orig.__name__, max_size=self.max_entries,
+ ) # type: LruCache[CacheKey, Any]
+
+ get_cache_key = self.cache_key_builder
+ sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+ @functools.wraps(self.orig)
+ def _wrapped(*args, **kwargs):
+ invalidate_callback = kwargs.pop("on_invalidate", None)
+ callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+ cache_key = get_cache_key(args, kwargs)
-class CacheDescriptor(_CacheDescriptorBase):
+ ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+ if ret != sentinel:
+ return ret
+
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+ ret2 = self.orig(obj, *args, **kwargs)
+ cache.set(cache_key, ret2, callbacks=callbacks)
+
+ return ret2
+
+ wrapped = cast(_CachedFunction, _wrapped)
+ wrapped.cache = cache
+ obj.__dict__[self.orig.__name__] = wrapped
+
+ return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
@@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_context=False,
iterable=False,
):
-
super().__init__(orig, num_args=num_args, cache_context=cache_context)
self.max_entries = max_entries
@@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase):
iterable=self.iterable,
) # type: DeferredCache[CacheKey, Any]
- def get_cache_key_gen(args, kwargs):
- """Given some args/kwargs return a generator that resolves into
- the cache_key.
-
- We loop through each arg name, looking up if its in the `kwargs`,
- otherwise using the next argument in `args`. If there are no more
- args then we try looking the arg name up in the defaults
- """
- pos = 0
- for nm in self.arg_names:
- if nm in kwargs:
- yield kwargs[nm]
- elif pos < len(args):
- yield args[pos]
- pos += 1
- else:
- yield self.arg_defaults[nm]
-
- # By default our cache key is a tuple, but if there is only one item
- # then don't bother wrapping in a tuple. This is to save memory.
- if self.num_args == 1:
- nm = self.arg_names[0]
-
- def get_cache_key(args, kwargs):
- if nm in kwargs:
- return kwargs[nm]
- elif len(args):
- return args[0]
- else:
- return self.arg_defaults[nm]
-
- else:
-
- def get_cache_key(args, kwargs):
- return tuple(get_cache_key_gen(args, kwargs))
+ get_cache_key = self.cache_key_builder
@functools.wraps(self.orig)
def _wrapped(*args, **kwargs):
@@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase):
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
else:
wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
wrapped.invalidate_many = cache.invalidate_many
wrapped.prefill = cache.prefill
@@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return wrapped
-class CacheListDescriptor(_CacheDescriptorBase):
+class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
@@ -382,11 +459,13 @@ class _CacheContext:
on a lower level.
"""
+ Cache = Union[DeferredCache, LruCache]
+
_cache_context_objects = (
WeakValueDictionary()
- ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
+ ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
- def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
+ def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
self._cache = cache
self._cache_key = cache_key
@@ -396,8 +475,8 @@ class _CacheContext:
@classmethod
def get_instance(
- cls, cache, cache_key
- ): # type: (DeferredCache, CacheKey) -> _CacheContext
+ cls, cache: "_CacheContext.Cache", cache_key: CacheKey
+ ) -> "_CacheContext":
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
@@ -418,7 +497,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
- func = lambda orig: CacheDescriptor(
+ func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@@ -460,7 +539,7 @@ def cachedList(
def batch_do_something(self, first_arg, second_args):
...
"""
- func = lambda orig: CacheListDescriptor(
+ func = lambda orig: DeferredCacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
@@ -468,3 +547,65 @@ def cachedList(
)
return cast(Callable[[F], _CachedFunction[F]], func)
+
+
+def get_cache_key_builder(
+ param_names: Sequence[str], param_defaults: Mapping[str, Any]
+) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
+ """Construct a function which will build cache keys suitable for a cached function
+
+ Args:
+ param_names: list of formal parameter names for the cached function
+ param_defaults: a mapping from parameter name to default value for that param
+
+ Returns:
+ A function which will take an (args, kwargs) pair and return a cache key
+ """
+
+ # By default our cache key is a tuple, but if there is only one item
+ # then don't bother wrapping in a tuple. This is to save memory.
+
+ if len(param_names) == 1:
+ nm = param_names[0]
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ if nm in kwargs:
+ return kwargs[nm]
+ elif len(args):
+ return args[0]
+ else:
+ return param_defaults[nm]
+
+ else:
+
+ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+ return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+ return get_cache_key
+
+
+def _get_cache_key_gen(
+ param_names: Iterable[str],
+ param_defaults: Mapping[str, Any],
+ args: Sequence[Any],
+ kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+ """Given some args/kwargs return a generator that resolves into
+ the cache_key.
+
+ This is essentially the same operation as `inspect.getcallargs`, but optimised so
+ that we don't need to inspect the target function for each call.
+ """
+
+    # We loop through each arg name, looking up if it is in the `kwargs`,
+ # otherwise using the next argument in `args`. If there are no more
+ # args then we try looking the arg name up in the defaults.
+ pos = 0
+ for nm in param_names:
+ if nm in kwargs:
+ yield kwargs[nm]
+ elif pos < len(args):
+ yield args[pos]
+ pos += 1
+ else:
+ yield param_defaults[nm]
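To make the single-argument optimisation concrete, a hedged usage sketch of `get_cache_key_builder` as defined above (the argument values are illustrative):

```python
# A single-parameter function uses the bare argument as its cache key...
build_key = get_cache_key_builder(["room_id"], {})
assert build_key(("!r:example.com",), {}) == "!r:example.com"

# ...while multi-parameter functions get a tuple, with any missing
# trailing arguments filled in from the declared defaults.
build_key = get_cache_key_builder(["room_id", "limit"], {"limit": 10})
assert build_key(("!r:example.com",), {}) == ("!r:example.com", 10)
assert build_key(("!r:example.com",), {"limit": 5}) == ("!r:example.com", 5)
```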
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
new file mode 100644
index 0000000000..77c261dbf7
--- /dev/null
+++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,277 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from binascii import unhexlify
+from typing import Tuple
+
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+from twisted.web.server import Request
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
+from synapse.server import HomeServer
+
+from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, FakeTransport
+
+logger = logging.getLogger(__name__)
+
+test_server_connection_factory = None
+
+
+class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks that running multiple media repos works correctly.
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("user", "pass")
+ self.access_token = self.login("user", "pass")
+
+ self.reactor.lookups["example.com"] = "127.0.0.2"
+
+ def default_config(self):
+ conf = super().default_config()
+ conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
+ return conf
+
+ def _get_media_req(
+ self, hs: HomeServer, target: str, media_id: str
+ ) -> Tuple[FakeChannel, Request]:
+ """Request some remote media from the given HS by calling the download
+ API.
+
+ This then triggers an outbound request from the HS to the target.
+
+ Returns:
+ The channel for the *client* request and the *outbound* request for
+ the media which the caller should respond to.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/{}/{}".format(target, media_id),
+ shorthand=False,
+ access_token=self.access_token,
+ )
+ request.render(hs.get_media_repository_resource().children[b"download"])
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+
+ # build the test server
+ server_tls_protocol = _build_test_server(get_connection_factory())
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_server = server_tls_protocol.wrappedProtocol
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(
+ request.path,
+ "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"),
+ )
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
+ )
+
+ return channel, request
+
+ def test_basic(self):
+ """Test basic fetching of remote media from a single worker.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+
+ channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
+
+ request.setResponseCode(200)
+ request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request.write(b"Hello!")
+ request.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.result["body"], b"Hello!")
+
+ def test_download_simple_file_race(self):
+ """Test that fetching remote media from two different processes at the
+ same time works.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_media()
+
+ # Make two requests without responding to the outbound media requests.
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
+
+ # Respond to the first outbound media request and check that the client
+ # request is successful
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request1.write(b"Hello!")
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], b"Hello!")
+
+ # Now respond to the second with the same content.
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"])
+ request2.write(b"Hello!")
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], b"Hello!")
+
+ # We expect only one new file to have been persisted.
+ self.assertEqual(start_count + 1, self._count_remote_media())
+
+ def test_download_image_race(self):
+ """Test that fetching remote *images* from two different processes at
+ the same time works.
+
+ This checks that races generating thumbnails are handled correctly.
+ """
+ hs1 = self.make_worker_hs("synapse.app.generic_worker")
+ hs2 = self.make_worker_hs("synapse.app.generic_worker")
+
+ start_count = self._count_remote_thumbnails()
+
+ channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
+ channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
+
+ png_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ request1.setResponseCode(200)
+ request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request1.write(png_data)
+ request1.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel1.code, 200, channel1.result["body"])
+ self.assertEqual(channel1.result["body"], png_data)
+
+ request2.setResponseCode(200)
+ request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
+ request2.write(png_data)
+ request2.finish()
+
+ self.pump(0.1)
+
+ self.assertEqual(channel2.code, 200, channel2.result["body"])
+ self.assertEqual(channel2.result["body"], png_data)
+
+ # We expect only three new thumbnails to have been persisted.
+ self.assertEqual(start_count + 3, self._count_remote_thumbnails())
+
+ def _count_remote_media(self) -> int:
+ """Count the number of files in our remote media directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_content"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+ def _count_remote_thumbnails(self) -> int:
+ """Count the number of files in our remote thumbnails directory.
+ """
+ path = os.path.join(
+ self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
+ )
+ return sum(len(files) for _, _, files in os.walk(path))
+
+
+def get_connection_factory():
+ # this needs to happen once, but not until we are ready to run the first test
+ global test_server_connection_factory
+ if test_server_connection_factory is None:
+ test_server_connection_factory = TestServerTLSConnectionFactory(
+ sanlist=[b"DNS:example.com"]
+ )
+ return test_server_connection_factory
+
+
+def _build_test_server(connection_creator):
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Args:
+ connection_creator (IOpenSSLServerConnectionCreator): thing to build
+ SSL connections
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=server_factory
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/server.py b/tests/server.py
index b97003fa5a..3dd2cfc072 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -46,7 +46,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
- result = attr.ib(default=attr.Factory(dict))
+ result = attr.ib(type=dict, default=attr.Factory(dict))
_producer = None
@property
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@
"""
Utilities for running the unit tests
"""
+import sys
+import warnings
from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
TV = TypeVar("TV")
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
future = Future() # type: ignore
future.set_result(result)
return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+ """
+    Convert warnings from non-awaited coroutines into errors.
+ """
+ warnings.simplefilter("error", RuntimeWarning)
+
+ # unraisablehook was added in Python 3.8.
+ if not hasattr(sys, "unraisablehook"):
+ return lambda: None
+
+ # State shared between unraisablehook and check_for_unraisable_exceptions.
+ unraisable_exceptions = []
+ orig_unraisablehook = sys.unraisablehook # type: ignore
+
+ def unraisablehook(unraisable):
+ unraisable_exceptions.append(unraisable.exc_value)
+
+ def cleanup():
+ """
+        A cleanup function which fails the test case if any new unraisable exceptions were raised.
+ """
+ sys.unraisablehook = orig_unraisablehook # type: ignore
+ if unraisable_exceptions:
+ raise unraisable_exceptions.pop()
+
+ sys.unraisablehook = unraisablehook # type: ignore
+
+ return cleanup
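For context on why both halves of `setup_awaitable_errors` are needed: promoting `RuntimeWarning` to an error is not enough on its own, because the "coroutine was never awaited" warning fires inside the coroutine's finalizer, where Python cannot raise normally; it surfaces through `sys.unraisablehook` instead. A minimal hedged reproduction (CPython 3.8+, relying on prompt refcount-based finalization):

```python
import gc
import sys
import warnings


async def do_something():
    return 42


warnings.simplefilter("error", RuntimeWarning)

caught = []
sys.unraisablehook = lambda unraisable: caught.append(unraisable.exc_value)

do_something()  # coroutine object created but never awaited
gc.collect()  # ensure the orphaned coroutine is finalized

# The promoted warning arrives via the unraisable hook, not a normal raise.
assert isinstance(caught[0], RuntimeWarning)
```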
diff --git a/tests/unittest.py b/tests/unittest.py
index 257f465897..08cf9b10c5 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -54,7 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
+ # Trial messes with the warnings configuration, thus this has to be
+ # done in the context of an individual TestCase.
+ self.addCleanup(setup_awaitable_errors())
+
return orig()
@around(self)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
from tests import unittest
+from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
+class LruCacheDecoratorTestCase(unittest.TestCase):
+ def test_base(self):
+ class Cls:
+ def __init__(self):
+ self.mock = mock.Mock()
+
+ @lru_cache()
+ def fn(self, arg1, arg2):
+ return self.mock(arg1, arg2)
+
+ obj = Cls()
+ obj.mock.return_value = "fish"
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ obj.mock.assert_called_once_with(1, 2)
+ obj.mock.reset_mock()
+
+ # a call with different params should call the mock again
+ obj.mock.return_value = "chips"
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_called_once_with(1, 3)
+ obj.mock.reset_mock()
+
+ # the two values should now be cached
+ r = obj.fn(1, 2)
+ self.assertEqual(r, "fish")
+ r = obj.fn(1, 3)
+ self.assertEqual(r, "chips")
+ obj.mock.assert_not_called()
+
+
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
@@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_invalidate_cascade(self):
+ """Invalidations should cascade up through cache contexts"""
+
+ class Cls:
+ @cached(cache_context=True)
+ async def func1(self, key, cache_context):
+ return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+ @cached(cache_context=True)
+ async def func2(self, key, cache_context):
+ return self.func3(key, on_invalidate=cache_context.invalidate)
+
+ @lru_cache(cache_context=True)
+ def func3(self, key, cache_context):
+ self.invalidate = cache_context.invalidate
+ return 42
+
+ obj = Cls()
+
+ top_invalidate = mock.Mock()
+ r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+ self.assertEqual(r, 42)
+ obj.invalidate()
+ top_invalidate.assert_called_once()
+
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
|