diff --git a/.gitignore b/.gitignore
index af36c00cfa..9bb5bdd647 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,6 +21,7 @@ _trial_temp*/
/.python-version
/*.signing.key
/env/
+/.venv*/
/homeserver*.yaml
/logs
/media_store/
diff --git a/changelog.d/8504.bugfix b/changelog.d/8504.bugfix
new file mode 100644
index 0000000000..2bd0dbb8b4
--- /dev/null
+++ b/changelog.d/8504.bugfix
@@ -0,0 +1 @@
+Expose the `uk.half-shot.msc2778.login.application_service` to clients from the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow.
diff --git a/changelog.d/8544.feature b/changelog.d/8544.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8544.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8545.bugfix b/changelog.d/8545.bugfix
new file mode 100644
index 0000000000..64ba307df0
--- /dev/null
+++ b/changelog.d/8545.bugfix
@@ -0,0 +1 @@
+Fix a long standing bug where email notifications for encrypted messages were blank.
diff --git a/changelog.d/8561.misc b/changelog.d/8561.misc
new file mode 100644
index 0000000000..a40dedfa8e
--- /dev/null
+++ b/changelog.d/8561.misc
@@ -0,0 +1 @@
+Move metric registration code down into `LruCache`.
diff --git a/changelog.d/8562.misc b/changelog.d/8562.misc
new file mode 100644
index 0000000000..ebdbddb500
--- /dev/null
+++ b/changelog.d/8562.misc
@@ -0,0 +1 @@
+Add type annotations for `LruCache`.
diff --git a/changelog.d/8563.misc b/changelog.d/8563.misc
new file mode 100644
index 0000000000..eeba8e5fee
--- /dev/null
+++ b/changelog.d/8563.misc
@@ -0,0 +1 @@
+Replace `DeferredCache` with the lighter-weight `LruCache` where possible.
diff --git a/changelog.d/8564.feature b/changelog.d/8564.feature
new file mode 100644
index 0000000000..45342e66ad
--- /dev/null
+++ b/changelog.d/8564.feature
@@ -0,0 +1 @@
+Support modifying event content in `ThirdPartyRules` modules.
diff --git a/changelog.d/8566.misc b/changelog.d/8566.misc
new file mode 100644
index 0000000000..453cf48ffa
--- /dev/null
+++ b/changelog.d/8566.misc
@@ -0,0 +1 @@
+Add virtualenv-generated folders to `.gitignore`.
\ No newline at end of file
diff --git a/changelog.d/8567.bugfix b/changelog.d/8567.bugfix
new file mode 100644
index 0000000000..4d835df6fd
--- /dev/null
+++ b/changelog.d/8567.bugfix
@@ -0,0 +1 @@
+Fix increase in the number of `There was no active span...` errors logged when using OpenTracing.
diff --git a/changelog.d/8568.misc b/changelog.d/8568.misc
new file mode 100644
index 0000000000..0ed7db92d3
--- /dev/null
+++ b/changelog.d/8568.misc
@@ -0,0 +1 @@
+Add `get_immediate` method to `DeferredCache`.
diff --git a/changelog.d/8569.misc b/changelog.d/8569.misc
new file mode 100644
index 0000000000..3b6e0625e5
--- /dev/null
+++ b/changelog.d/8569.misc
@@ -0,0 +1 @@
+Fix mypy not properly checking across the codebase, additionally, fix a typing assertion error in `handlers/auth.py`.
\ No newline at end of file
diff --git a/changelog.d/8571.misc b/changelog.d/8571.misc
new file mode 100644
index 0000000000..f6a65057e0
--- /dev/null
+++ b/changelog.d/8571.misc
@@ -0,0 +1 @@
+Fix `synmark` benchmark runner.
diff --git a/changelog.d/8572.misc b/changelog.d/8572.misc
new file mode 100644
index 0000000000..ea2a6d340d
--- /dev/null
+++ b/changelog.d/8572.misc
@@ -0,0 +1 @@
+Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.
diff --git a/changelog.d/8577.misc b/changelog.d/8577.misc
new file mode 100644
index 0000000000..75fe563a02
--- /dev/null
+++ b/changelog.d/8577.misc
@@ -0,0 +1 @@
+Adjust a protocol-type definition to fit `sqlite3` assertions.
\ No newline at end of file
diff --git a/changelog.d/8578.misc b/changelog.d/8578.misc
new file mode 100644
index 0000000000..e93462255b
--- /dev/null
+++ b/changelog.d/8578.misc
@@ -0,0 +1 @@
+Support macOS on the `synmark` benchmark runner.
diff --git a/changelog.d/8583.misc b/changelog.d/8583.misc
new file mode 100644
index 0000000000..d24973f09a
--- /dev/null
+++ b/changelog.d/8583.misc
@@ -0,0 +1 @@
+Update `mypy` static type checker to 0.790.
\ No newline at end of file
diff --git a/changelog.d/8585.bugfix b/changelog.d/8585.bugfix
new file mode 100644
index 0000000000..e97e6ac1d8
--- /dev/null
+++ b/changelog.d/8585.bugfix
@@ -0,0 +1 @@
+Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed.
\ No newline at end of file
diff --git a/changelog.d/8587.misc b/changelog.d/8587.misc
new file mode 100644
index 0000000000..9e56551a34
--- /dev/null
+++ b/changelog.d/8587.misc
@@ -0,0 +1 @@
+Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.
diff --git a/changelog.d/8589.removal b/changelog.d/8589.removal
new file mode 100644
index 0000000000..b80f29d6bb
--- /dev/null
+++ b/changelog.d/8589.removal
@@ -0,0 +1 @@
+Drop unused `device_max_stream_id` table.
diff --git a/changelog.d/8590.misc b/changelog.d/8590.misc
new file mode 100644
index 0000000000..4abcccb326
--- /dev/null
+++ b/changelog.d/8590.misc
@@ -0,0 +1 @@
+Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.
diff --git a/changelog.d/8591.misc b/changelog.d/8591.misc
new file mode 100644
index 0000000000..8f16bc3e7e
--- /dev/null
+++ b/changelog.d/8591.misc
@@ -0,0 +1 @@
+ Move metric registration code down into `LruCache`.
diff --git a/changelog.d/8592.misc b/changelog.d/8592.misc
new file mode 100644
index 0000000000..099e8fb7bb
--- /dev/null
+++ b/changelog.d/8592.misc
@@ -0,0 +1 @@
+Remove extraneous unittest logging decorators from unit tests.
\ No newline at end of file
diff --git a/changelog.d/8593.misc b/changelog.d/8593.misc
new file mode 100644
index 0000000000..d266ba19a4
--- /dev/null
+++ b/changelog.d/8593.misc
@@ -0,0 +1 @@
+Minor optimisations in caching code.
diff --git a/changelog.d/8594.misc b/changelog.d/8594.misc
new file mode 100644
index 0000000000..d266ba19a4
--- /dev/null
+++ b/changelog.d/8594.misc
@@ -0,0 +1 @@
+Minor optimisations in caching code.
diff --git a/changelog.d/8599.feature b/changelog.d/8599.feature
new file mode 100644
index 0000000000..542993110b
--- /dev/null
+++ b/changelog.d/8599.feature
@@ -0,0 +1 @@
+Allow running background tasks in a separate worker process.
diff --git a/changelog.d/8600.misc b/changelog.d/8600.misc
new file mode 100644
index 0000000000..a5a922e641
--- /dev/null
+++ b/changelog.d/8600.misc
@@ -0,0 +1 @@
+Update `mypy` static type checker to 0.790.
diff --git a/changelog.d/8606.feature b/changelog.d/8606.feature
new file mode 100644
index 0000000000..fad723c108
--- /dev/null
+++ b/changelog.d/8606.feature
@@ -0,0 +1 @@
+Limit appservice transactions to 100 persistent and 100 ephemeral events.
diff --git a/changelog.d/8609.misc b/changelog.d/8609.misc
new file mode 100644
index 0000000000..5e3f3c1993
--- /dev/null
+++ b/changelog.d/8609.misc
@@ -0,0 +1 @@
+Add type hints to profile and base handler.
diff --git a/mypy.ini b/mypy.ini
index b5db54ee3b..5e9f7b1259 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -15,8 +15,9 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
- synapse/handlers/appservice.py,
+ synapse/handlers/_base.py,
synapse/handlers/account_data.py,
+ synapse/handlers/appservice.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
@@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
+ synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 2d0b59ab53..6c7664ad4a 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -22,6 +22,7 @@ import logging
import sys
import time
import traceback
+from typing import Optional
import yaml
@@ -152,7 +153,7 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes.
-end_error = None
+end_error = None # type: Optional[str]
# The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run
@@ -635,7 +636,7 @@ class Porter(object):
self.progress.done()
except Exception as e:
global end_error_exec_info
- end_error = e
+ end_error = str(e)
end_error_exec_info = sys.exc_info()
logger.exception("")
finally:
diff --git a/setup.py b/setup.py
index 08843fe2a3..2f4a3170d2 100755
--- a/setup.py
+++ b/setup.py
@@ -102,6 +102,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.
# Tests assume that all optional dependencies are installed.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1071a0576e..bff87fabde 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
from synapse.types import StateMap, UserID
-from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -70,8 +69,9 @@ class Auth:
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.token_cache = LruCache(10000)
- register_cache("cache", "token_cache", self.token_cache)
+ self.token_cache = LruCache(
+ 10000, "token_cache"
+ ) # type: LruCache[str, Tuple[str, bool]]
self._auth_blocking = AuthBlocking(self.hs)
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index ad3c408519..58291afc22 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -60,6 +60,13 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+# Maximum number of events to provide in an AS transaction.
+MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
+
+# Maximum number of ephemeral events to provide in an AS transaction.
+MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
+
+
class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
@@ -136,10 +143,17 @@ class _ServiceQueuer:
self.requests_in_flight.add(service.id)
try:
while True:
- events = self.queued_events.pop(service.id, [])
- ephemeral = self.queued_ephemeral.pop(service.id, [])
+ all_events = self.queued_events.get(service.id, [])
+ events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+ del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+
+ all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
+ ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+ del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+
if not events and not ephemeral:
return
+
try:
await self.txn_ctrl.send(service, events, ephemeral)
except Exception:
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index df4f950fec..07df258e6e 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -98,7 +98,7 @@ class EventBuilder:
return self._state_key is not None
async def build(
- self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
+ self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
) -> EventBase:
"""Transform into a fully signed and hashed event
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..bd8e71ae56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
import synapse.state
import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -56,7 +56,7 @@ class BaseHandler:
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
- )
+ ) # type: Optional[Ratelimiter]
else:
self.admin_redaction_ratelimiter = None
@@ -127,15 +127,15 @@ class BaseHandler:
if guest_access != "can_join":
if context:
current_state_ids = await context.get_current_state_ids()
- current_state = await self.store.get_events(
+ current_state_dict = await self.store.get_events(
list(current_state_ids.values())
)
+ current_state = list(current_state_dict.values())
else:
- current_state = await self.state_handler.get_current_state(
+ current_state_map = await self.state_handler.get_current_state(
event.room_id
)
-
- current_state = list(current_state.values())
+ current_state = list(current_state_map.values())
logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state)
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index f33044e97a..fd4f762f33 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -22,7 +22,7 @@ from typing import List
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
@@ -63,16 +63,10 @@ class AccountValidityHandler:
self._raw_from = email.utils.parseaddr(self._from_string)[1]
# Check the renewal emails to send and send them every 30min.
- def send_emails():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "send_renewals", self._send_renewal_emails
- )
-
if hs.config.run_background_tasks:
- self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
+ @wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1d1ddc2245..8619fbb982 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1122,20 +1122,22 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
- def _do_validate_hash():
+ def _do_validate_hash(checked_hash: bytes):
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
- stored_hash,
+ checked_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
- return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+ return await defer_to_thread(
+ self.hs.get_reactor(), _do_validate_hash, stored_hash
+ )
else:
return False
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 98075f48d2..cb11754bf8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
+ # The member_event_id will always be available if membership is set
+ # to leave.
+ assert member_event_id
+
result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
@@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
- membership: Membership,
+ membership: str,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
@@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
- membership: Membership,
+ membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e37bca3d12..e4f2115617 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1364,7 +1364,12 @@ class EventCreationHandler:
for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v)
- event = await builder.build(prev_event_ids=original_event.prev_event_ids())
+ # the event type hasn't changed, so there's no point in re-calculating the
+ # auth events.
+ event = await builder.build(
+ prev_event_ids=original_event.prev_event_ids(),
+ auth_event_ids=original_event.auth_event_ids(),
+ )
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b784938755..92700b589c 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import (
AuthError,
@@ -24,11 +24,20 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester, get_domain_from_id
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.types import (
+ JsonDict,
+ Requester,
+ UserID,
+ create_requester,
+ get_domain_from_id,
+)
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
@@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation = hs.get_federation_client()
@@ -57,10 +66,10 @@ class ProfileHandler(BaseHandler):
if hs.config.run_background_tasks:
self.clock.looping_call(
- self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
- async def get_profile(self, user_id):
+ async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
@@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- async def get_profile_from_cache(self, user_id):
+ async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
@@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
- async def get_displayname(self, target_user):
+ async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
@@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
return result["displayname"]
async def set_displayname(
- self, target_user, requester, new_displayname, by_admin=False
- ):
+ self,
+ target_user: UserID,
+ requester: Requester,
+ new_displayname: str,
+ by_admin: bool = False,
+ ) -> None:
"""Set the displayname of a user
Args:
- target_user (UserID): the user whose displayname is to be changed.
- requester (Requester): The user attempting to make this change.
- new_displayname (str): The displayname to give this user.
- by_admin (bool): Whether this change was made by an administrator.
+ target_user: the user whose displayname is to be changed.
+ requester: The user attempting to make this change.
+ new_displayname: The displayname to give this user.
+ by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
+ displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "":
- new_displayname = None
+ displayname_to_set = None
# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
@@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(
+ target_user.localpart, displayname_to_set
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
- async def get_avatar_url(self, target_user):
+ async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
@@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
return result["avatar_url"]
async def set_avatar_url(
- self, target_user, requester, new_avatar_url, by_admin=False
+ self,
+ target_user: UserID,
+ requester: Requester,
+ new_avatar_url: str,
+ by_admin: bool = False,
):
"""Set a new avatar URL for a user.
Args:
- target_user (UserID): the user whose avatar URL is to be changed.
- requester (Requester): The user attempting to make this change.
- new_avatar_url (str): The avatar URL to give this user.
- by_admin (bool): Whether this change was made by an administrator.
+ target_user: the user whose avatar URL is to be changed.
+ requester: The user attempting to make this change.
+ new_avatar_url: The avatar URL to give this user.
+ by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
- async def on_profile_query(self, args):
+ async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
return response
- async def _update_join_states(self, requester, target_user):
+ async def _update_join_states(
+ self, requester: Requester, target_user: UserID
+ ) -> None:
if not self.hs.is_mine(target_user):
return
@@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e)
)
- async def check_profile_query_allowed(self, target_user, requester=None):
+ async def check_profile_query_allowed(
+ self, target_user: UserID, requester: Optional[UserID] = None
+ ) -> None:
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
Args:
- target_user (UserID): The owner of the queried profile.
- requester (None|UserID): The user querying for the profile.
+ target_user: The owner of the queried profile.
+ requester: The user querying for the profile.
Raises:
SynapseError(403): The two users share no room, or ne user couldn't
@@ -370,11 +394,7 @@ class ProfileHandler(BaseHandler):
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
raise
- def _start_update_remote_profile_cache(self):
- return run_as_background_process(
- "Update remote profile", self._update_remote_profile_cache
- )
-
+ @wrap_as_background_process("Update remote profile")
async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
new file mode 100644
index 0000000000..0caf325916
--- /dev/null
+++ b/synapse/logging/_remote.py
@@ -0,0 +1,225 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import traceback
+from collections import deque
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from math import floor
+from typing import Callable, Optional
+
+import attr
+from zope.interface import implementer
+
+from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
+from twisted.internet.endpoints import (
+ HostnameEndpoint,
+ TCP4ClientEndpoint,
+ TCP6ClientEndpoint,
+)
+from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.protocol import Factory, Protocol
+from twisted.logger import ILogObserver, Logger, LogLevel
+
+
+@attr.s
+@implementer(IPushProducer)
+class LogProducer:
+ """
+ An IPushProducer that writes logs from its buffer to its transport when it
+ is resumed.
+
+ Args:
+ buffer: Log buffer to read logs from.
+ transport: Transport to write to.
+ format_event: A callable to format the log entry to a string.
+ """
+
+ transport = attr.ib(type=ITransport)
+ format_event = attr.ib(type=Callable[[dict], str])
+ _buffer = attr.ib(type=deque)
+ _paused = attr.ib(default=False, type=bool, init=False)
+
+ def pauseProducing(self):
+ self._paused = True
+
+ def stopProducing(self):
+ self._paused = True
+ self._buffer = deque()
+
+ def resumeProducing(self):
+ self._paused = False
+
+ while self._paused is False and (self._buffer and self.transport.connected):
+ try:
+ # Request the next event and format it.
+ event = self._buffer.popleft()
+ msg = self.format_event(event)
+
+ # Send it as a new line over the transport.
+ self.transport.write(msg.encode("utf8"))
+ except Exception:
+ # Something has gone wrong writing to the transport -- log it
+ # and break out of the while.
+ traceback.print_exc(file=sys.__stderr__)
+ break
+
+
+@attr.s
+@implementer(ILogObserver)
+class TCPLogObserver:
+ """
+ An IObserver that writes JSON logs to a TCP target.
+
+ Args:
+ hs (HomeServer): The homeserver that is being logged for.
+ host: The host of the logging target.
+ port: The logging target's port.
+ format_event: A callable to format the log entry to a string.
+ maximum_buffer: The maximum buffer size.
+ """
+
+ hs = attr.ib()
+ host = attr.ib(type=str)
+ port = attr.ib(type=int)
+ format_event = attr.ib(type=Callable[[dict], str])
+ maximum_buffer = attr.ib(type=int)
+ _buffer = attr.ib(default=attr.Factory(deque), type=deque)
+ _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
+ _logger = attr.ib(default=attr.Factory(Logger))
+ _producer = attr.ib(default=None, type=Optional[LogProducer])
+
+ def start(self) -> None:
+
+ # Connect without DNS lookups if it's a direct IP.
+ try:
+ ip = ip_address(self.host)
+ if isinstance(ip, IPv4Address):
+ endpoint = TCP4ClientEndpoint(
+ self.hs.get_reactor(), self.host, self.port
+ )
+ elif isinstance(ip, IPv6Address):
+ endpoint = TCP6ClientEndpoint(
+ self.hs.get_reactor(), self.host, self.port
+ )
+ else:
+ raise ValueError("Unknown IP address provided: %s" % (self.host,))
+ except ValueError:
+ endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+
+ factory = Factory.forProtocol(Protocol)
+ self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+ self._service.startService()
+ self._connect()
+
+ def stop(self):
+ self._service.stopService()
+
+ def _connect(self) -> None:
+ """
+ Triggers an attempt to connect then write to the remote if not already writing.
+ """
+ if self._connection_waiter:
+ return
+
+ self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+ @self._connection_waiter.addErrback
+ def fail(r):
+ r.printTraceback(file=sys.__stderr__)
+ self._connection_waiter = None
+ self._connect()
+
+ @self._connection_waiter.addCallback
+ def writer(r):
+ # We have a connection. If we already have a producer, and its
+ # transport is the same, just trigger a resumeProducing.
+ if self._producer and r.transport is self._producer.transport:
+ self._producer.resumeProducing()
+ self._connection_waiter = None
+ return
+
+ # If the producer is still producing, stop it.
+ if self._producer:
+ self._producer.stopProducing()
+
+ # Make a new producer and start it.
+ self._producer = LogProducer(
+ buffer=self._buffer,
+ transport=r.transport,
+ format_event=self.format_event,
+ )
+ r.transport.registerProducer(self._producer, True)
+ self._producer.resumeProducing()
+ self._connection_waiter = None
+
+ def _handle_pressure(self) -> None:
+ """
+ Handle backpressure by shedding events.
+
+ The buffer will, in this order, until the buffer is below the maximum:
+ - Shed DEBUG events
+ - Shed INFO events
+ - Shed the middle 50% of the events.
+ """
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Strip out DEBUGs
+ self._buffer = deque(
+ filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
+ )
+
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Strip out INFOs
+ self._buffer = deque(
+ filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
+ )
+
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Cut the middle entries out
+ buffer_split = floor(self.maximum_buffer / 2)
+
+ old_buffer = self._buffer
+ self._buffer = deque()
+
+ for i in range(buffer_split):
+ self._buffer.append(old_buffer.popleft())
+
+ end_buffer = []
+ for i in range(buffer_split):
+ end_buffer.append(old_buffer.pop())
+
+ self._buffer.extend(reversed(end_buffer))
+
+ def __call__(self, event: dict) -> None:
+ self._buffer.append(event)
+
+ # Handle backpressure, if it exists.
+ try:
+ self._handle_pressure()
+ except Exception:
+ # If handling backpressure fails,clear the buffer and log the
+ # exception.
+ self._buffer.clear()
+ self._logger.failure("Failed clearing backpressure")
+
+ # Try and write immediately.
+ self._connect()
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 1b8916cfa2..9b46956ca9 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -18,26 +18,11 @@ Log formatters that output terse JSON.
"""
import json
-import sys
-import traceback
-from collections import deque
-from ipaddress import IPv4Address, IPv6Address, ip_address
-from math import floor
-from typing import IO, Optional
+from typing import IO
-import attr
-from zope.interface import implementer
+from twisted.logger import FileLogObserver
-from twisted.application.internet import ClientService
-from twisted.internet.defer import Deferred
-from twisted.internet.endpoints import (
- HostnameEndpoint,
- TCP4ClientEndpoint,
- TCP6ClientEndpoint,
-)
-from twisted.internet.interfaces import IPushProducer, ITransport
-from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import FileLogObserver, ILogObserver, Logger
+from synapse.logging._remote import TCPLogObserver
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
@@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
return FileLogObserver(outFile, formatEvent)
-@attr.s
-@implementer(IPushProducer)
-class LogProducer:
+def TerseJSONToTCPLogObserver(
+ hs, host: str, port: int, metadata: dict, maximum_buffer: int
+) -> FileLogObserver:
"""
- An IPushProducer that writes logs from its buffer to its transport when it
- is resumed.
-
- Args:
- buffer: Log buffer to read logs from.
- transport: Transport to write to.
- """
-
- transport = attr.ib(type=ITransport)
- _buffer = attr.ib(type=deque)
- _paused = attr.ib(default=False, type=bool, init=False)
-
- def pauseProducing(self):
- self._paused = True
-
- def stopProducing(self):
- self._paused = True
- self._buffer = deque()
-
- def resumeProducing(self):
- self._paused = False
-
- while self._paused is False and (self._buffer and self.transport.connected):
- try:
- event = self._buffer.popleft()
- self.transport.write(_encoder.encode(event).encode("utf8"))
- self.transport.write(b"\n")
- except Exception:
- # Something has gone wrong writing to the transport -- log it
- # and break out of the while.
- traceback.print_exc(file=sys.__stderr__)
- break
-
-
-@attr.s
-@implementer(ILogObserver)
-class TerseJSONToTCPLogObserver:
- """
- An IObserver that writes JSON logs to a TCP target.
+ A log observer that formats events to a flattened JSON representation.
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
- metadata: Metadata to be added to each log entry.
+ metadata: Metadata to be added to each log object.
+ maximum_buffer: The maximum buffer size.
"""
- hs = attr.ib()
- host = attr.ib(type=str)
- port = attr.ib(type=int)
- metadata = attr.ib(type=dict)
- maximum_buffer = attr.ib(type=int)
- _buffer = attr.ib(default=attr.Factory(deque), type=deque)
- _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
- _logger = attr.ib(default=attr.Factory(Logger))
- _producer = attr.ib(default=None, type=Optional[LogProducer])
-
- def start(self) -> None:
-
- # Connect without DNS lookups if it's a direct IP.
- try:
- ip = ip_address(self.host)
- if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(
- self.hs.get_reactor(), self.host, self.port
- )
- elif isinstance(ip, IPv6Address):
- endpoint = TCP6ClientEndpoint(
- self.hs.get_reactor(), self.host, self.port
- )
- except ValueError:
- endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
-
- factory = Factory.forProtocol(Protocol)
- self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
- self._service.startService()
- self._connect()
-
- def stop(self):
- self._service.stopService()
-
- def _connect(self) -> None:
- """
- Triggers an attempt to connect then write to the remote if not already writing.
- """
- if self._connection_waiter:
- return
-
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
- @self._connection_waiter.addErrback
- def fail(r):
- r.printTraceback(file=sys.__stderr__)
- self._connection_waiter = None
- self._connect()
-
- @self._connection_waiter.addCallback
- def writer(r):
- # We have a connection. If we already have a producer, and its
- # transport is the same, just trigger a resumeProducing.
- if self._producer and r.transport is self._producer.transport:
- self._producer.resumeProducing()
- self._connection_waiter = None
- return
-
- # If the producer is still producing, stop it.
- if self._producer:
- self._producer.stopProducing()
-
- # Make a new producer and start it.
- self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
- r.transport.registerProducer(self._producer, True)
- self._producer.resumeProducing()
- self._connection_waiter = None
-
- def _handle_pressure(self) -> None:
- """
- Handle backpressure by shedding events.
-
- The buffer will, in this order, until the buffer is below the maximum:
- - Shed DEBUG events
- - Shed INFO events
- - Shed the middle 50% of the events.
- """
- if len(self._buffer) <= self.maximum_buffer:
- return
-
- # Strip out DEBUGs
- self._buffer = deque(
- filter(lambda event: event["level"] != "DEBUG", self._buffer)
- )
-
- if len(self._buffer) <= self.maximum_buffer:
- return
-
- # Strip out INFOs
- self._buffer = deque(
- filter(lambda event: event["level"] != "INFO", self._buffer)
- )
-
- if len(self._buffer) <= self.maximum_buffer:
- return
-
- # Cut the middle entries out
- buffer_split = floor(self.maximum_buffer / 2)
-
- old_buffer = self._buffer
- self._buffer = deque()
-
- for i in range(buffer_split):
- self._buffer.append(old_buffer.popleft())
-
- end_buffer = []
- for i in range(buffer_split):
- end_buffer.append(old_buffer.pop())
-
- self._buffer.extend(reversed(end_buffer))
-
- def __call__(self, event: dict) -> None:
- flattened = flatten_event(event, self.metadata, include_time=True)
- self._buffer.append(flattened)
-
- # Handle backpressure, if it exists.
- try:
- self._handle_pressure()
- except Exception:
- # If handling backpressure fails,clear the buffer and log the
- # exception.
- self._buffer.clear()
- self._logger.failure("Failed clearing backpressure")
+ def formatEvent(_event: dict) -> str:
+ flattened = flatten_event(_event, metadata, include_time=True)
+ return _encoder.encode(flattened) + "\n"
- # Try and write immediately.
- self._connect()
+ return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 5b73463504..ea5f1c7b62 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -24,6 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.opentracing import start_active_span
if TYPE_CHECKING:
import resource
@@ -197,14 +198,14 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
with BackgroundProcessLoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
-
try:
- result = func(*args, **kwargs)
+ with start_active_span(desc, tags={"request_id": context.request}):
+ result = func(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
+ if inspect.isawaitable(result):
+ result = await result
- return result
+ return result
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c440f2545c..a701defcdd 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
# dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow.
def __call__(self):
- rules = self.cache.get(self.room_id, None, update_metrics=False)
+ rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 455a1acb46..155791b754 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -387,8 +387,8 @@ class Mailer:
return ret
async def get_message_vars(self, notif, event, room_state_ids):
- if event.type != EventTypes.Message:
- return
+ if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
+ return None
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = await self.store.get_event(sender_state_event_id)
@@ -399,10 +399,8 @@ class Mailer:
# sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender)
- msgtype = event.content.get("msgtype")
-
ret = {
- "msgtype": msgtype,
+ "event_type": event.type,
"is_historical": event.event_id != notif["event_id"],
"id": event.event_id,
"ts": event.origin_server_ts,
@@ -411,6 +409,14 @@ class Mailer:
"sender_hash": sender_hash,
}
+ # Encrypted messages don't have any additional useful information.
+ if event.type == EventTypes.Encrypted:
+ return ret
+
+ msgtype = event.content.get("msgtype")
+
+ ret["msgtype"] = msgtype
+
if msgtype == "m.text":
self.add_text_message_vars(ret, event)
elif msgtype == "m.image":
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 3a68ce636f..2ce9e444ab 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,11 +16,10 @@
import logging
import re
-from typing import Any, Dict, List, Optional, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
from synapse.types import UserID
-from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -174,20 +173,21 @@ class PushRuleEvaluatorForEvent:
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
- r = re.escape(display_name)
- r = _re_word_boundary(r)
- r = re.compile(r, flags=re.IGNORECASE)
+ r1 = re.escape(display_name)
+ r1 = _re_word_boundary(r1)
+ r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
- return r.search(body)
+ return bool(r.search(body))
def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000)
-register_cache("cache", "regex_push_cache", regex_cache)
+regex_cache = LruCache(
+ 50000, "regex_push_cache"
+) # type: LruCache[Tuple[str, bool, bool], Pattern]
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@@ -205,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r:
r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r
- return r.search(value)
+ return bool(r.search(value))
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 4b0ea0cc01..0f5b7adef7 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
from ._base import BaseSlavedStore
@@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.client_ip_last_seen = DeferredCache(
- name="client_ip_last_seen", keylen=4, max_entries=50000
- ) # type: DeferredCache[tuple, int]
+ self.client_ip_last_seen = LruCache(
+ cache_name="client_ip_last_seen", keylen=4, max_size=50000
+ ) # type: LruCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
@@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
- self.client_ip_last_seen.prefill(key, now)
+ self.client_ip_last_seen.set(key, now)
self.hs.get_tcp_replication().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now
diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 1a6c70b562..6d76064d13 100644
--- a/synapse/res/templates/notif.html
+++ b/synapse/res/templates/notif.html
@@ -1,41 +1,47 @@
-{% for message in notif.messages %}
+{%- for message in notif.messages %}
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
<td class="sender_avatar">
- {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
- {% if message.sender_avatar_url %}
+ {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ {%- if message.sender_avatar_url %}
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
- {% else %}
- {% if message.sender_hash % 3 == 0 %}
+ {%- else %}
+ {%- if message.sender_hash % 3 == 0 %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" />
- {% elif message.sender_hash % 3 == 1 %}
+ {%- elif message.sender_hash % 3 == 1 %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" />
- {% else %}
+ {%- else %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" />
- {% endif %}
- {% endif %}
- {% endif %}
+ {%- endif %}
+ {%- endif %}
+ {%- endif %}
</td>
<td class="message_contents">
- {% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
- <div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
- {% endif %}
+ {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
+ <div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div>
+ {%- endif %}
<div class="message_body">
- {% if message.msgtype == "m.text" %}
- {{ message.body_text_html }}
- {% elif message.msgtype == "m.emote" %}
- {{ message.body_text_html }}
- {% elif message.msgtype == "m.notice" %}
- {{ message.body_text_html }}
- {% elif message.msgtype == "m.image" %}
- <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
- {% elif message.msgtype == "m.file" %}
- <span class="filename">{{ message.body_text_plain }}</span>
- {% endif %}
+ {%- if message.event_type == "m.room.encrypted" %}
+ An encrypted message.
+ {%- elif message.event_type == "m.room.message" %}
+ {%- if message.msgtype == "m.text" %}
+ {{ message.body_text_html }}
+ {%- elif message.msgtype == "m.emote" %}
+ {{ message.body_text_html }}
+ {%- elif message.msgtype == "m.notice" %}
+ {{ message.body_text_html }}
+ {%- elif message.msgtype == "m.image" %}
+ <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
+ {%- elif message.msgtype == "m.file" %}
+ <span class="filename">{{ message.body_text_plain }}</span>
+ {%- else %}
+ A message with unrecognised content.
+ {%- endif %}
+ {%- endif %}
</div>
</td>
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
</tr>
-{% endfor %}
+{%- endfor %}
<tr class="notif_link">
<td></td>
<td>
diff --git a/synapse/res/templates/notif.txt b/synapse/res/templates/notif.txt
index a37bee9833..1ee7da3c50 100644
--- a/synapse/res/templates/notif.txt
+++ b/synapse/res/templates/notif.txt
@@ -1,16 +1,22 @@
-{% for message in notif.messages %}
-{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
-{% if message.msgtype == "m.text" %}
+{%- for message in notif.messages %}
+{%- if message.event_type == "m.room.encrypted" %}
+An encrypted message.
+{%- elif message.event_type == "m.room.message" %}
+{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{%- if message.msgtype == "m.text" %}
{{ message.body_text_plain }}
-{% elif message.msgtype == "m.emote" %}
+{%- elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }}
-{% elif message.msgtype == "m.notice" %}
+{%- elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }}
-{% elif message.msgtype == "m.image" %}
+{%- elif message.msgtype == "m.image" %}
{{ message.body_text_plain }}
-{% elif message.msgtype == "m.file" %}
+{%- elif message.msgtype == "m.file" %}
{{ message.body_text_plain }}
-{% endif %}
-{% endfor %}
+{%- else %}
+A message with unrecognised content.
+{%- endif %}
+{%- endif %}
+{%- endfor %}
View {{ room.title }} at {{ notif.link }}
diff --git a/synapse/res/templates/notif_mail.html b/synapse/res/templates/notif_mail.html
index a2dfeb9e9f..27d4182790 100644
--- a/synapse/res/templates/notif_mail.html
+++ b/synapse/res/templates/notif_mail.html
@@ -2,8 +2,8 @@
<html lang="en">
<head>
<style type="text/css">
- {% include 'mail.css' without context %}
- {% include "mail-%s.css" % app_name ignore missing without context %}
+ {%- include 'mail.css' without context %}
+ {%- include "mail-%s.css" % app_name ignore missing without context %}
</style>
</head>
<body>
@@ -18,21 +18,21 @@
<div class="summarytext">{{ summary_text }}</div>
</td>
<td class="logo">
- {% if app_name == "Riot" %}
+ {%- if app_name == "Riot" %}
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
- {% elif app_name == "Vector" %}
+ {%- elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
- {% elif app_name == "Element" %}
+ {%- elif app_name == "Element" %}
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
- {% else %}
+ {%- else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
- {% endif %}
+ {%- endif %}
</td>
</tr>
</table>
- {% for room in rooms %}
- {% include 'room.html' with context %}
- {% endfor %}
+ {%- for room in rooms %}
+ {%- include 'room.html' with context %}
+ {%- endfor %}
<div class="footer">
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
<br/>
@@ -41,12 +41,12 @@
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
- {% if reason.last_sent_ts %}
+ {%- if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
- {% else %}
+ {%- else %}
and we don't have a last time we sent a mail for this room.
- {% endif %}
+ {%- endif %}
</div>
</div>
</td>
diff --git a/synapse/res/templates/notif_mail.txt b/synapse/res/templates/notif_mail.txt
index 24843042a5..df3c253979 100644
--- a/synapse/res/templates/notif_mail.txt
+++ b/synapse/res/templates/notif_mail.txt
@@ -2,9 +2,9 @@ Hi {{ user_display_name }},
{{ summary_text }}
-{% for room in rooms %}
-{% include 'room.txt' with context %}
-{% endfor %}
+{%- for room in rooms %}
+{%- include 'room.txt' with context %}
+{%- endfor %}
You can disable these notifications at {{ unsubscribe_link }}
diff --git a/synapse/res/templates/room.html b/synapse/res/templates/room.html
index b8525fef88..4fc6f6ac9b 100644
--- a/synapse/res/templates/room.html
+++ b/synapse/res/templates/room.html
@@ -1,23 +1,23 @@
<table class="room">
<tr class="room_header">
<td class="room_avatar">
- {% if room.avatar_url %}
+ {%- if room.avatar_url %}
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
- {% else %}
- {% if room.hash % 3 == 0 %}
+ {%- else %}
+ {%- if room.hash % 3 == 0 %}
<img alt="" src="https://riot.im/img/external/avatar-1.png" />
- {% elif room.hash % 3 == 1 %}
+ {%- elif room.hash % 3 == 1 %}
<img alt="" src="https://riot.im/img/external/avatar-2.png" />
- {% else %}
+ {%- else %}
<img alt="" src="https://riot.im/img/external/avatar-3.png" />
- {% endif %}
- {% endif %}
+ {%- endif %}
+ {%- endif %}
</td>
<td class="room_name" colspan="2">
{{ room.title }}
</td>
</tr>
- {% if room.invite %}
+ {%- if room.invite %}
<tr>
<td></td>
<td>
@@ -25,9 +25,9 @@
</td>
<td></td>
</tr>
- {% else %}
- {% for notif in room.notifs %}
- {% include 'notif.html' with context %}
- {% endfor %}
- {% endif %}
+ {%- else %}
+ {%- for notif in room.notifs %}
+ {%- include 'notif.html' with context %}
+ {%- endfor %}
+ {%- endif %}
</table>
diff --git a/synapse/res/templates/room.txt b/synapse/res/templates/room.txt
index 84648c710e..df841e9e6f 100644
--- a/synapse/res/templates/room.txt
+++ b/synapse/res/templates/room.txt
@@ -1,9 +1,9 @@
{{ room.title }}
-{% if room.invite %}
+{%- if room.invite %}
You've been invited, join at {{ room.link }}
-{% else %}
- {% for notif in room.notifs %}
- {% include 'notif.txt' with context %}
- {% endfor %}
-{% endif %}
+{%- else %}
+ {%- for notif in room.notifs %}
+ {%- include 'notif.txt' with context %}
+ {%- endfor %}
+{%- endif %}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7deb9300d..b82a4e978a 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -110,6 +110,8 @@ class LoginRestServlet(RestServlet):
({"type": t} for t in self.auth_handler.get_supported_login_types())
)
+ flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
+
return 200, {"flows": flows}
def on_OPTIONS(self, request: SynapseRequest):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab49d227de..2b196ded1b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta):
"""
try:
- if key is None:
- getattr(self, cache_name).invalidate_all()
- else:
- getattr(self, cache_name).invalidate(tuple(key))
+ cache = getattr(self, cache_name)
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
- pass
+ return
+
+ if key is None:
+ cache.invalidate_all()
+ else:
+ cache.invalidate(tuple(key))
def db_to_json(db_content):
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 763722d6bc..0217e63108 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -160,7 +160,7 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self
- def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index ec8260e906..2408432738 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
@@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- self.client_ip_last_seen = DeferredCache(
- name="client_ip_last_seen", keylen=4, max_entries=50000
+ self.client_ip_last_seen = LruCache(
+ cache_name="client_ip_last_seen", keylen=4, max_size=50000
)
super().__init__(database, db_conn, hs)
@@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
- self.client_ip_last_seen.prefill(key, now)
+ self.client_ip_last_seen.set(key, now)
self._batch_row_update[key] = (user_agent, device_id, now)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e662a20d24..dfb4f87b8f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -34,8 +34,8 @@ from synapse.storage.database import (
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
-from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
- self.device_id_exists_cache = DeferredCache(
- name="device_id_exists", keylen=2, max_entries=10000
+ self.device_id_exists_cache = LruCache(
+ cache_name="device_id_exists", keylen=2, max_size=10000
)
async def store_device(
@@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
- self.device_id_exists_cache.prefill(key, True)
+ self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
raise
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ba3b1769b0..87808c1483 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1051,9 +1051,7 @@ class PersistEventsStore:
def prefill():
for cache_entry in to_prefill:
- self.store._get_event_cache.prefill(
- (cache_entry[0].event_id,), cache_entry
- )
+ self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
txn.call_after(prefill)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ff150f0be7..6e7f16f39c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -33,7 +33,10 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
@@ -42,8 +45,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
-from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached
+from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,20 +140,16 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
- if not hs.config.worker.worker_app:
+ if hs.config.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
- run_as_background_process,
- 5 * 60 * 1000,
- "_cleanup_old_transaction_ids",
- self._cleanup_old_transaction_ids,
+ self._cleanup_old_transaction_ids, 5 * 60 * 1000,
)
- self._get_event_cache = DeferredCache(
- "*getEvent*",
+ self._get_event_cache = LruCache(
+ cache_name="*getEvent*",
keylen=3,
- max_entries=hs.config.caches.event_cache_size,
- apply_cache_factor_from_config=False,
+ max_size=hs.config.caches.event_cache_size,
)
self._event_fetch_lock = threading.Condition()
@@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore):
event=original_ev, redacted_event=redacted_event
)
- self._get_event_cache.prefill((event_id,), cache_entry)
+ self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry
return result_map
@@ -1375,6 +1374,7 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
+ @wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs.
"""
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 1681caa1f0..a6d1eb908a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
- self, user_localpart: str, new_displayname: str
+ self, user_localpart: str, new_displayname: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",
@@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
- ) -> Dict[str, str]:
+ ) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df8609b97b..7997242d90 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
lock=False,
)
- user_has_pusher = self.get_if_user_has_pusher.cache.get(
+ user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
(user_id,), None, update_metrics=False
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 5cdf16521c..ca7917c989 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
if receipt_type != "m.read":
return
- # Returns either an ObservableDeferred or the raw result
- res = self.get_users_with_read_receipts_in_room.cache.get(
+ res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
room_id, None, update_metrics=False
)
- # first handle the ObservableDeferred case
- if isinstance(res, ObservableDeferred):
- if res.has_called():
- res = res.get_result()
- else:
- res = None
-
if res and user_id in res:
# We'd only be adding to the set, so no point invalidating if the
# user is already there
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 20fcdaa529..01d9dbb36f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -20,7 +20,10 @@ from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -67,16 +70,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
- run_as_background_process,
- 60 * 1000,
- "_count_known_servers",
- self._count_known_servers,
+ self._count_known_servers, 60 * 1000,
)
self.hs.get_clock().call_later(
- 1000,
- run_as_background_process,
- "_count_known_servers",
- self._count_known_servers,
+ 1000, self._count_known_servers,
)
LaterGauge(
"synapse_federation_known_servers",
@@ -85,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
+ @wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -531,7 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids`
if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get(
+ prev_res = self._get_joined_users_from_context.cache.get_immediate(
(room_id, context.prev_group), None
)
if prev_res and isinstance(prev_res, dict):
diff --git a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
index 20f5a95a24..7b84a207fd 100644
--- a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/58/21as_device_stream.sql
@@ -13,6 +13,5 @@
* limitations under the License.
*/
-ALTER TABLE application_services_state
- ADD COLUMN read_receipt_stream_id INT,
- ADD COLUMN presence_stream_id INT;
\ No newline at end of file
+ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
+ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
new file mode 100644
index 0000000000..01ea6eddcf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/21drop_device_max_stream_id.sql
@@ -0,0 +1 @@
+DROP TABLE device_max_stream_id;
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 970bb1b9da..9cadcba18f 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing_extensions import Protocol
@@ -65,5 +65,5 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...
- def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
...
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8fc05be278..89f0b38535 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -16,7 +16,7 @@
import logging
from sys import intern
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Sized
import attr
from prometheus_client.core import Gauge
@@ -92,7 +92,7 @@ class CacheMetric:
def register_cache(
cache_type: str,
cache_name: str,
- cache,
+ cache: Sized,
collect_callback: Optional[Callable] = None,
resizable: bool = True,
resize_callback: Optional[Callable] = None,
@@ -100,12 +100,15 @@ def register_cache(
"""Register a cache object for metric collection and resizing.
Args:
- cache_type
+ cache_type: a string indicating the "type" of the cache. This is used
+ only for deduplication so isn't too important provided it's constant.
cache_name: name of the cache
- cache: cache itself
+ cache: cache itself, which must implement __len__(), and may optionally implement
+ a max_size property
collect_callback: If given, a function which is called during metric
collection to update additional metrics.
- resizable: Whether this cache supports being resized.
+ resizable: Whether this cache supports being resized, in which case either
+ resize_callback must be provided, or the cache must support set_max_size().
resize_callback: A function which can be called to resize the cache.
Returns:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index f728cd2cf2..601305487c 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -17,14 +17,23 @@
import enum
import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
+from typing import (
+ Callable,
+ Generic,
+ Iterable,
+ MutableMapping,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
from prometheus_client import Gauge
from twisted.internet import defer
+from twisted.python import failure
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -34,7 +43,7 @@ cache_pending_metric = Gauge(
["name"],
)
-
+T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
@@ -49,15 +58,12 @@ class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
- may return an ObservableDeferred.
+ will return a Deferred.
"""
__slots__ = (
"cache",
- "name",
- "keylen",
"thread",
- "metrics",
"_pending_deferred_cache",
)
@@ -89,37 +95,27 @@ class DeferredCache(Generic[KT, VT]):
cache_type()
) # type: MutableMapping[KT, CacheEntry]
+ def metrics_cb():
+ cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
+
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
+ cache_name=name,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
- evicted_callback=self._on_evicted,
+ metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
- )
+ ) # type: LruCache[KT, VT]
- self.name = name
- self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
- self.metrics = register_cache(
- "cache",
- name,
- self.cache,
- collect_callback=self._metrics_collection_callback,
- )
@property
def max_entries(self):
return self.cache.max_size
- def _on_evicted(self, evicted_count):
- self.metrics.inc_evictions(evicted_count)
-
- def _metrics_collection_callback(self):
- cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
-
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -133,62 +129,113 @@ class DeferredCache(Generic[KT, VT]):
def get(
self,
key: KT,
- default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
- ):
+ ) -> defer.Deferred:
"""Looks the key up in the caches.
+ For symmetry with set(), this method does *not* follow the synapse logcontext
+ rules: the logcontext will not be cleared on return, and the Deferred will run
+ its callbacks in the sentinel context. In other words: wrap the result with
+ make_deferred_yieldable() before `await`ing it.
+
Args:
- key(tuple)
- default: What is returned if key is not in the caches. If not
- specified then function throws KeyError instead
- callback(fn): Gets called when the entry in the cache is invalidated
+ key:
+ callback: Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either an ObservableDeferred or the result itself
+ A Deferred which completes with the result. Note that this may later fail
+ if there is an ongoing set() operation which later completes with a failure.
+
+ Raises:
+ KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
- self.metrics.inc_hits()
- return val.deferred
-
- val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
- if val is not _Sentinel.sentinel:
- self.metrics.inc_hits()
- return val
+ m = self.cache.metrics
+ assert m # we always have a name, so should always have metrics
+ m.inc_hits()
+ return val.deferred.observe()
- if update_metrics:
- self.metrics.inc_misses()
-
- if default is _Sentinel.sentinel:
+ val2 = self.cache.get(
+ key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
+ )
+ if val2 is _Sentinel.sentinel:
raise KeyError()
else:
- return default
+ return defer.succeed(val2)
+
+ def get_immediate(
+ self, key: KT, default: T, update_metrics: bool = True
+ ) -> Union[VT, T]:
+ """If we have a *completed* cached value, return it."""
+ return self.cache.get(key, default, update_metrics=update_metrics)
def set(
self,
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
- ) -> ObservableDeferred:
+ ) -> defer.Deferred:
+ """Adds a new entry to the cache (or updates an existing one).
+
+ The given `value` *must* be a Deferred.
+
+ First any existing entry for the same key is invalidated. Then a new entry
+ is added to the cache for the given key.
+
+ Until the `value` completes, calls to `get()` for the key will also result in an
+ incomplete Deferred, which will ultimately complete with the same result as
+ `value`.
+
+ If `value` completes successfully, subsequent calls to `get()` will then return
+ a completed deferred with the same result. If it *fails*, the cache is
+ invalidated and subequent calls to `get()` will raise a KeyError.
+
+ If another call to `set()` happens before `value` completes, then (a) any
+ invalidation callbacks registered in the interim will be called, (b) any
+ `get()`s in the interim will continue to complete with the result from the
+ *original* `value`, (c) any future calls to `get()` will complete with the
+ result from the *new* `value`.
+
+ It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+ if it is incomplete, it runs its callbacks in the sentinel context.
+
+ Args:
+ key: Key to be set
+ value: a deferred which will complete with a result to add to the cache
+ callback: An optional callback to be called when the entry is invalidated
+ """
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
+ # XXX: why don't we invalidate the entry in `self.cache` yet?
+
+ # we can save a whole load of effort if the deferred is ready.
+ if value.called:
+ result = value.result
+ if not isinstance(result, failure.Failure):
+ self.cache.set(key, result, callbacks)
+ return value
+
+ # otherwise, we'll add an entry to the _pending_deferred_cache for now,
+ # and add callbacks to add it to the cache properly later.
+
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = observable.observe()
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
+
self._pending_deferred_cache[key] = entry
def compare_and_pop():
@@ -232,7 +279,9 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
- return observable
+
+ # we return a new Deferred which will be called before any subsequent observers.
+ return observable.observe()
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
@@ -257,11 +306,12 @@ class DeferredCache(Generic[KT, VT]):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+ key = cast(KT, key)
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
+ entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1f43886804..5d7fffee66 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
- ) # type: DeferredCache[Tuple, Any]
+ ) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@@ -202,32 +201,20 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_key = get_cache_key(args, kwargs)
- # Add our own `cache_context` to argument list if the wrapped function
- # has asked for one
- if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
-
try:
- cached_result_d = cache.get(cache_key, callback=invalidate_callback)
-
- if isinstance(cached_result_d, ObservableDeferred):
- observer = cached_result_d.observe()
- else:
- observer = defer.succeed(cached_result_d)
-
+ ret = cache.get(cache_key, callback=invalidate_callback)
except KeyError:
- ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+ # Add our own `cache_context` to argument list if the wrapped function
+ # has asked for one
+ if self.add_cache_context:
+ kwargs["cache_context"] = _CacheContext.get_instance(
+ cache, cache_key
+ )
- def onErr(f):
- cache.invalidate(cache_key)
- return f
-
- ret.addErrback(onErr)
-
- result_d = cache.set(cache_key, ret, callback=invalidate_callback)
- observer = result_d.observe()
+ ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
+ ret = cache.set(cache_key, ret, callback=invalidate_callback)
- return make_deferred_yieldable(observer)
+ return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -286,7 +273,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name)
- cache = cached_method.cache
+ cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args
@functools.wraps(self.orig)
@@ -326,14 +313,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not isinstance(res, ObservableDeferred):
- results[arg] = res
- elif not res.has_succeeded():
- res = res.observe()
+ if not res.called:
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
- results[arg] = res.get_result()
+ results[arg] = res.result
except KeyError:
missing.add(arg)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8592b93689..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,15 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import logging
import threading
from collections import namedtuple
+from typing import Any
from synapse.util.caches.lrucache import LruCache
-from . import register_cache
-
logger = logging.getLogger(__name__)
@@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value)
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
+
+
class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
- self.cache = LruCache(max_size=max_entries, size_callback=len)
+ self.cache = LruCache(
+ max_size=max_entries, cache_name=name, size_callback=len
+ ) # type: LruCache[Any, DictionaryEntry]
self.name = name
self.sequence = 0
self.thread = None
- # caches_by_name[name] = self.cache
-
- class Sentinel:
- __slots__ = []
-
- self.sentinel = Sentinel()
- self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -80,10 +80,8 @@ class DictionaryCache:
Returns:
DictionaryEntry
"""
- entry = self.cache.get(key, self.sentinel)
- if entry is not self.sentinel:
- self.metrics.inc_hits()
-
+ entry = self.cache.get(key, _Sentinel.sentinel)
+ if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
@@ -95,7 +93,6 @@ class DictionaryCache:
{k: entry.value[k] for k in dict_keys if k in entry.value},
)
- self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})
def invalidate(self, key):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 33eae2b7c4..60bb6ff642 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,11 +15,35 @@
import threading
from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import Literal
from synapse.config import cache as cache_config
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
def enumerate_leaves(node, depth):
if depth == 0:
@@ -41,29 +65,31 @@ class _Node:
self.callbacks = callbacks
-class LruCache:
+class LruCache(Generic[KT, VT]):
"""
- Least-recently-used cache.
+ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
+
Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples.
-
- Can also set callbacks on objects when getting/setting which are fired
- when that key gets invalidated/evicted.
"""
def __init__(
self,
max_size: int,
+ cache_name: Optional[str] = None,
keylen: int = 1,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable] = None,
- evicted_callback: Optional[Callable] = None,
+ metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
max_size: The maximum amount of entries the cache can hold
+ cache_name: The name of this cache, for the prometheus metrics. If unset,
+ no metrics will be reported on this cache.
+
keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`.
@@ -73,9 +99,13 @@ class LruCache:
size_callback (func(V) -> int | None):
- evicted_callback (func(int)|None):
- if not None, called on eviction with the size of the evicted
- entry
+ metrics_collection_callback:
+ metrics collection callback. This is called early in the metrics
+ collection process, before any of the metrics registered with the
+ prometheus Registry are collected, so can be used to update any dynamic
+ metrics.
+
+ Ignored if cache_name is None.
apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config
@@ -94,6 +124,23 @@ class LruCache:
else:
self.max_size = int(max_size)
+ # register_cache might call our "set_cache_factor" callback; there's nothing to
+ # do yet when we get resized.
+ self._on_resize = None # type: Optional[Callable[[],None]]
+
+ if cache_name is not None:
+ metrics = register_cache(
+ "lru_cache",
+ cache_name,
+ self,
+ collect_callback=metrics_collection_callback,
+ ) # type: Optional[CacheMetric]
+ else:
+ metrics = None
+
+ # this is exposed for access from outside this class
+ self.metrics = metrics
+
list_root = _Node(None, None, None, None)
list_root.next_node = list_root
list_root.prev_node = list_root
@@ -105,16 +152,16 @@ class LruCache:
todelete = list_root.prev_node
evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
- if evicted_callback:
- evicted_callback(evicted_len)
+ if metrics:
+ metrics.inc_evictions(evicted_len)
- def synchronized(f):
+ def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
- return inner
+ return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
@@ -168,18 +215,45 @@ class LruCache:
node.callbacks.clear()
return deleted_len
+ @overload
+ def cache_get(
+ key: KT,
+ default: Literal[None] = None,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_get(
+ key: KT,
+ default: T,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_get(key, default=None, callbacks=[]):
+ def cache_get(
+ key: KT,
+ default: Optional[T] = None,
+ callbacks: Iterable[Callable[[], None]] = [],
+ update_metrics: bool = True,
+ ):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
node.callbacks.update(callbacks)
+ if update_metrics and metrics:
+ metrics.inc_hits()
return node.value
else:
+ if update_metrics and metrics:
+ metrics.inc_misses()
return default
@synchronized
- def cache_set(key, value, callbacks=[]):
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@@ -208,7 +282,7 @@ class LruCache:
evict()
@synchronized
- def cache_set_default(key, value):
+ def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@@ -217,8 +291,16 @@ class LruCache:
evict()
return value
+ @overload
+ def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_pop(key: KT, default: T) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_pop(key, default=None):
+ def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
@@ -228,18 +310,18 @@ class LruCache:
return default
@synchronized
- def cache_del_multi(key):
+ def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
- for leaf in enumerate_leaves(popped, keylen - len(key)):
+ for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
- def cache_clear():
+ def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
@@ -250,15 +332,21 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
- def cache_contains(key):
+ def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()
+
+ # make sure that we clear out any excess entries after we get resized.
self._on_resize = evict
+
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ # `invalidate` is exposed for consistency with DeferredCache, so that it can be
+ # invalidated by the cache invalidation replication stream.
+ self.invalidate = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = synchronized(cache_len)
@@ -302,6 +390,7 @@ class LruCache:
new_size = int(self._original_max_size * factor)
if new_size != self.max_size:
self.max_size = new_size
- self._on_resize()
+ if self._on_resize:
+ self._on_resize()
return True
return False
diff --git a/synmark/__init__.py b/synmark/__init__.py
index 53698bd5ab..09bc7e7927 100644
--- a/synmark/__init__.py
+++ b/synmark/__init__.py
@@ -15,7 +15,10 @@
import sys
-from twisted.internet import epollreactor
+try:
+ from twisted.internet.epollreactor import EPollReactor as Reactor
+except ImportError:
+ from twisted.internet.pollreactor import PollReactor as Reactor
from twisted.internet.main import installReactor
from synapse.config.homeserver import HomeServerConfig
@@ -41,7 +44,7 @@ async def make_homeserver(reactor, config=None):
config_obj = HomeServerConfig()
config_obj.parse_config_dict(config, "", "")
- hs = await setup_test_homeserver(
+ hs = setup_test_homeserver(
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
)
stor = hs.get_datastore()
@@ -63,7 +66,7 @@ def make_reactor():
Instantiate and install a Twisted reactor suitable for testing (i.e. not the
default global one).
"""
- reactor = epollreactor.EPollReactor()
+ reactor = Reactor()
if "twisted.internet.reactor" in sys.modules:
del sys.modules["twisted.internet.reactor"]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 2acb8b7603..97f8cad0dd 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -260,6 +260,31 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
+ def test_send_large_txns(self):
+ srv_1_defer = defer.Deferred()
+ srv_2_defer = defer.Deferred()
+ send_return_list = [srv_1_defer, srv_2_defer]
+
+ def do_send(x, y, z):
+ return make_deferred_yieldable(send_return_list.pop(0))
+
+ self.txn_ctrl.send = Mock(side_effect=do_send)
+
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
+ for event in event_list:
+ self.queuer.enqueue_event(service, event)
+
+ # Expect the first event to be sent immediately.
+ self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
+ srv_1_defer.callback(service)
+ # Then send the next 100 events
+ self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
+ srv_2_defer.callback(service)
+ # Then the final 99 events
+ self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
+ self.assertEquals(3, self.txn_ctrl.send.call_count)
+
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
@@ -296,3 +321,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.assertEquals(2, self.txn_ctrl.send.call_count)
+
+ def test_send_large_txns_ephemeral(self):
+ d = defer.Deferred()
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
+ # Expect the event to be sent immediately.
+ service = Mock(id=4, name="service")
+ first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
+ second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
+ event_list = first_chunk + second_chunk
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
+ d.callback(service)
+ self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
+ self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 4cf81f7128..fd128b88e0 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -78,7 +78,7 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"server_name",
"name",
]
- self.assertEqual(set(log.keys()), set(expected_log_keys))
+ self.assertCountEqual(log.keys(), expected_log_keys)
# It contains the data we expect.
self.assertEqual(log["name"], "wally")
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 3224568640..55545d9341 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -158,8 +158,21 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_encrypted_message(self):
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends some messages
+ self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
def _check_for_mail(self):
- "Check that the user receives an email notification"
+ """Check that the user receives an email notification"""
# Get the stream ordering before it gets sent
pushers = self.get_success(
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index ae2cd67f35..66ac4dbe85 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -352,7 +352,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(request.code, 401)
- @unittest.INFO
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastore()
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 293ccfba2b..86184f0d2e 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -104,7 +104,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")
- @unittest.INFO
def test_fallback_captcha(self):
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8e69b1e9cc..1ac4ebc61d 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -15,237 +15,9 @@
# limitations under the License.
-from mock import Mock
-
-from twisted.internet import defer
-
-from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached
-
from tests import unittest
-class CacheDecoratorTestCase(unittest.HomeserverTestCase):
- @defer.inlineCallbacks
- def test_passthrough(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- a = A()
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals((yield a.func("bar")), "bar")
-
- @defer.inlineCallbacks
- def test_hit(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- self.assertEquals((yield a.func("foo")), "foo")
- self.assertEquals(callcount[0], 1)
-
- @defer.inlineCallbacks
- def test_invalidate(self):
- callcount = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 1)
-
- a.func.invalidate(("foo",))
-
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
-
- def test_invalidate_missing(self):
- class A:
- @cached()
- def func(self, key):
- return key
-
- A().func.invalidate(("what",))
-
- @defer.inlineCallbacks
- def test_max_entries(self):
- callcount = [0]
-
- class A:
- @cached(max_entries=10)
- def func(self, key):
- callcount[0] += 1
- return key
-
- a = A()
-
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertEquals(callcount[0], 12)
-
- # There must have been at least 2 evictions, meaning if we calculate
- # all 12 values again, we must get called at least 2 more times
- for k in range(0, 12):
- yield a.func(k)
-
- self.assertTrue(
- callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
- )
-
- def test_prefill(self):
- callcount = [0]
-
- d = defer.succeed(123)
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return d
-
- a = A()
-
- a.func.prefill(("foo",), ObservableDeferred(d))
-
- self.assertEquals(a.func("foo").result, d.result)
- self.assertEquals(callcount[0], 0)
-
- @defer.inlineCallbacks
- def test_invalidate_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func.invalidate(("foo",))
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 1)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- @defer.inlineCallbacks
- def test_eviction_context(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached(max_entries=2)
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- yield a.func2("foo")
- yield a.func2("foo2")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func("foo3")
-
- self.assertEquals(callcount[0], 3)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 4)
- self.assertEquals(callcount2[0], 3)
-
- @defer.inlineCallbacks
- def test_double_get(self):
- callcount = [0]
- callcount2 = [0]
-
- class A:
- @cached()
- def func(self, key):
- callcount[0] += 1
- return key
-
- @cached(cache_context=True)
- def func2(self, key, cache_context):
- callcount2[0] += 1
- return self.func(key, on_invalidate=cache_context.invalidate)
-
- a = A()
- a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 1)
-
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
-
- yield a.func2("foo")
- a.func2.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
-
- self.assertEquals(callcount[0], 1)
- self.assertEquals(callcount2[0], 2)
-
- a.func.invalidate(("foo",))
- self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
- yield a.func("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 2)
-
- yield a.func2("foo")
-
- self.assertEquals(callcount[0], 2)
- self.assertEquals(callcount2[0], 3)
-
-
class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore()
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 9717be56b6..dadfabd46d 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from functools import partial
from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
+from tests.unittest import TestCase
-class DeferredCacheTestCase(unittest.TestCase):
+
+class DeferredCacheTestCase(TestCase):
def test_empty(self):
cache = DeferredCache("test")
failed = False
@@ -36,7 +37,118 @@ class DeferredCacheTestCase(unittest.TestCase):
cache = DeferredCache("test")
cache.prefill("foo", 123)
- self.assertEquals(cache.get("foo"), 123)
+ self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+
+ def test_hit_deferred(self):
+ cache = DeferredCache("test")
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d)
+
+ # get should return an incomplete deferred
+ get_d = cache.get("k1")
+ self.assertFalse(get_d.called)
+
+ # add a callback that will make sure that the set_d gets called before the get_d
+ def check1(r):
+ self.assertTrue(set_d.called)
+ return r
+
+ # TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
+ # maybe we should fix that?
+ # get_d.addCallback(check1)
+
+ # now fire off all the deferreds
+ origin_d.callback(99)
+ self.assertEqual(self.successResultOf(origin_d), 99)
+ self.assertEqual(self.successResultOf(set_d), 99)
+ self.assertEqual(self.successResultOf(get_d), 99)
+
+ def test_callbacks(self):
+ """Invalidation callbacks are called at the right time"""
+ cache = DeferredCache("test")
+ callbacks = set()
+
+ # start with an entry, with a callback
+ cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+ # now replace that entry with a pending result
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+ # ... and also make a get request
+ get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+ # we don't expect the invalidation callback for the original value to have
+ # been called yet, even though get() will now return a different result.
+ # I'm not sure if that is by design or not.
+ self.assertEqual(callbacks, set())
+
+ # now fire off all the deferreds
+ origin_d.callback(20)
+ self.assertEqual(self.successResultOf(set_d), 20)
+ self.assertEqual(self.successResultOf(get_d), 20)
+
+ # now the original invalidation callback should have been called, but none of
+ # the others
+ self.assertEqual(callbacks, {"prefill"})
+ callbacks.clear()
+
+ # another update should invalidate both the previous results
+ cache.prefill("k1", 30)
+ self.assertEqual(callbacks, {"set", "get"})
+
+ def test_set_fail(self):
+ cache = DeferredCache("test")
+ callbacks = set()
+
+ # start with an entry, with a callback
+ cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
+
+ # now replace that entry with a pending result
+ origin_d = defer.Deferred()
+ set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
+
+ # ... and also make a get request
+ get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
+
+ # none of the callbacks should have been called yet
+ self.assertEqual(callbacks, set())
+
+ # oh noes! fails!
+ e = Exception("oops")
+ origin_d.errback(e)
+ self.assertIs(self.failureResultOf(set_d, Exception).value, e)
+ self.assertIs(self.failureResultOf(get_d, Exception).value, e)
+
+ # the callbacks for the failed requests should have been called.
+ # I'm not sure if this is deliberate or not.
+ self.assertEqual(callbacks, {"get", "set"})
+ callbacks.clear()
+
+ # the old value should still be returned now?
+ get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
+ self.assertEqual(self.successResultOf(get_d2), 10)
+
+ # replacing the value now should run the callbacks for those requests
+ # which got the original result
+ cache.prefill("k1", 30)
+ self.assertEqual(callbacks, {"prefill", "get2"})
+
+ def test_get_immediate(self):
+ cache = DeferredCache("test")
+ d1 = defer.Deferred()
+ cache.set("key1", d1)
+
+ # get_immediate should return default
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 1)
+
+ # now complete the set
+ d1.callback(2)
+
+ # get_immediate should return result
+ v = cache.get_immediate("key1", 1)
+ self.assertEqual(v, 2)
def test_invalidate(self):
cache = DeferredCache("test")
@@ -66,23 +178,24 @@ class DeferredCacheTestCase(unittest.TestCase):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
- # lookup should return observable deferreds
- self.assertFalse(cache.get("key1").has_called())
- self.assertFalse(cache.get("key2").has_called())
+ # lookup should return pending deferreds
+ self.assertFalse(cache.get("key1").called)
+ self.assertFalse(cache.get("key2").called)
# let one of the lookups complete
d2.callback("result2")
- # for now at least, the cache will return real results rather than an
- # observabledeferred
- self.assertEqual(cache.get("key2"), "result2")
+ # now the cache will return a completed deferred
+ self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
# now do the invalidation
cache.invalidate_all()
- # lookup should return none
- self.assertIsNone(cache.get("key1", None))
- self.assertIsNone(cache.get("key2", None))
+ # lookup should fail
+ with self.assertRaises(KeyError):
+ cache.get("key1")
+ with self.assertRaises(KeyError):
+ cache.get("key2")
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
@@ -90,7 +203,8 @@ class DeferredCacheTestCase(unittest.TestCase):
# letting the other lookup complete should do nothing
d1.callback("result1")
- self.assertIsNone(cache.get("key1", None))
+ with self.assertRaises(KeyError):
+ cache.get("key1", None)
def test_eviction(self):
cache = DeferredCache(
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3d1f960869..2ad08f541b 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Set
import mock
@@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1)
self.failureResultOf(d, SynapseError)
+ def test_cache_with_async_exception(self):
+ """The wrapped function returns a failure
+ """
+
+ class Cls:
+ result = None
+ call_count = 0
+
+ @cached()
+ def fn(self, arg1):
+ self.call_count += 1
+ return self.result
+
+ obj = Cls()
+ callbacks = set() # type: Set[str]
+
+ # set off an asynchronous request
+ obj.result = origin_d = defer.Deferred()
+
+ d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
+ self.assertFalse(d1.called)
+
+ # a second request should also return a deferred, but should not call the
+ # function itself.
+ d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
+ self.assertFalse(d2.called)
+ self.assertEqual(obj.call_count, 1)
+
+ # no callbacks yet
+ self.assertEqual(callbacks, set())
+
+ # the original request fails
+ e = Exception("bzz")
+ origin_d.errback(e)
+
+ # ... which should cause the lookups to fail similarly
+ self.assertIs(self.failureResultOf(d1, Exception).value, e)
+ self.assertIs(self.failureResultOf(d2, Exception).value, e)
+
+ # ... and the callbacks to have been, uh, called.
+ self.assertEqual(callbacks, {"d1", "d2"})
+
+ # ... leaving the cache empty
+ self.assertEqual(len(obj.fn.cache.cache), 0)
+
+ # and a second call should work as normal
+ obj.result = defer.succeed(100)
+ d3 = obj.fn(1)
+ self.assertEqual(self.successResultOf(d3), 100)
+ self.assertEqual(obj.call_count, 2)
+
def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
@@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d, SynapseError)
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
+ """More tests for @cached
+
+ The following is a set of tests that got lost in a different file for a while.
+
+ There are probably duplicates of the tests in DescriptorTestCase. Ideally the
+ duplicates would be removed and the two sets of classes combined.
+ """
+
+ @defer.inlineCallbacks
+ def test_passthrough(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ a = A()
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals((yield a.func("bar")), "bar")
+
+ @defer.inlineCallbacks
+ def test_hit(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ self.assertEquals((yield a.func("foo")), "foo")
+ self.assertEquals(callcount[0], 1)
+
+ @defer.inlineCallbacks
+ def test_invalidate(self):
+ callcount = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 1)
+
+ a.func.invalidate(("foo",))
+
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+
+ def test_invalidate_missing(self):
+ class A:
+ @cached()
+ def func(self, key):
+ return key
+
+ A().func.invalidate(("what",))
+
+ @defer.inlineCallbacks
+ def test_max_entries(self):
+ callcount = [0]
+
+ class A:
+ @cached(max_entries=10)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ a = A()
+
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertEquals(callcount[0], 12)
+
+ # There must have been at least 2 evictions, meaning if we calculate
+ # all 12 values again, we must get called at least 2 more times
+ for k in range(0, 12):
+ yield a.func(k)
+
+ self.assertTrue(
+ callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
+ )
+
+ def test_prefill(self):
+ callcount = [0]
+
+ d = defer.succeed(123)
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return d
+
+ a = A()
+
+ a.func.prefill(("foo",), 456)
+
+ self.assertEquals(a.func("foo").result, 456)
+ self.assertEquals(callcount[0], 0)
+
+ @defer.inlineCallbacks
+ def test_invalidate_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func.invalidate(("foo",))
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 1)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ @defer.inlineCallbacks
+ def test_eviction_context(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached(max_entries=2)
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ yield a.func2("foo")
+ yield a.func2("foo2")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func("foo3")
+
+ self.assertEquals(callcount[0], 3)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 4)
+ self.assertEquals(callcount2[0], 3)
+
+ @defer.inlineCallbacks
+ def test_double_get(self):
+ callcount = [0]
+ callcount2 = [0]
+
+ class A:
+ @cached()
+ def func(self, key):
+ callcount[0] += 1
+ return key
+
+ @cached(cache_context=True)
+ def func2(self, key, cache_context):
+ callcount2[0] += 1
+ return self.func(key, on_invalidate=cache_context.invalidate)
+
+ a = A()
+ a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 1)
+
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
+
+ yield a.func2("foo")
+ a.func2.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
+
+ self.assertEquals(callcount[0], 1)
+ self.assertEquals(callcount2[0], 2)
+
+ a.func.invalidate(("foo",))
+ self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
+ yield a.func("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 2)
+
+ yield a.func2("foo")
+
+ self.assertEquals(callcount[0], 2)
+ self.assertEquals(callcount2[0], 3)
+
+
class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 0adb2174af..a739a6aaaf 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -19,7 +19,8 @@ from mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
-from .. import unittest
+from tests import unittest
+from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
@@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache.pop("key"), None)
def test_del_multi(self):
- cache = LruCache(4, 2, cache_type=TreeCache)
+ cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache.clear()
self.assertEquals(len(cache), 0)
+ @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
+ def test_special_size(self):
+ cache = LruCache(10, "mycache")
+ self.assertEqual(cache.max_size, 100)
+
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self):
@@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m2 = Mock()
m3 = Mock()
m4 = Mock()
- cache = LruCache(4, 2, cache_type=TreeCache)
+ cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
diff --git a/tox.ini b/tox.ini
index 4d132eff4c..6dcc439a40 100644
--- a/tox.ini
+++ b/tox.ini
@@ -158,12 +158,9 @@ commands=
coverage html
[testenv:mypy]
-skip_install = True
deps =
{[base]deps}
- mypy==0.782
- mypy-zope
-extras = all
+extras = all,mypy
commands = mypy
# To find all folders that pass mypy you run:
|