From 39230d217104f3cd7aba9065dc478f935ce1e614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Mar 2020 14:45:33 +0000 Subject: Clean up some LoggingContext stuff (#7120) * Pull Sentinel out of LoggingContext ... and drop a few unnecessary references to it * Factor out LoggingContext.current_context move `current_context` and `set_context` out to top-level functions. Mostly this means that I can more easily trace what's actually referring to LoggingContext, but I think it's generally neater. * move copy-to-parent into `stop` this really just makes `start` and `stop` more symetric. It also means that it behaves correctly if you manually `set_log_context` rather than using the context manager. * Replace `LoggingContext.alive` with `finished` Turn `alive` into `finished` and make it a bit better defined. --- synapse/util/metrics.py | 4 ++-- synapse/util/patch_inline_callbacks.py | 36 +++++++++++++++++----------------- 2 files changed, 20 insertions(+), 20 deletions(-) (limited to 'synapse/util') diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 7b18455469..ec61e14423 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -21,7 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ class Measure(object): raise RuntimeError("Measure() objects cannot be re-used") self.start = self.clock.time() - parent_context = LoggingContext.current_context() + parent_context = current_context() self._logging_context = LoggingContext( "Measure[%s]" % (self.name,), parent_context ) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 3925927f9f..fdff195771 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -32,7 +32,7 @@ def do_patch(): Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context global _already_patched @@ -43,35 +43,35 @@ def do_patch(): def new_inline_callbacks(f): @functools.wraps(f) def wrapped(*args, **kwargs): - start_context = LoggingContext.current_context() + start_context = current_context() changes = [] # type: List[str] orig = orig_inline_callbacks(_check_yield_points(f, changes)) try: res = orig(*args, **kwargs) except Exception: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s changed context from %s to %s on exception" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) raise if not isinstance(res, Deferred) or res.called: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "Completed %s changed context from %s to %s" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -79,23 +79,23 @@ def do_patch(): raise Exception(err) return res - if LoggingContext.current_context() != LoggingContext.sentinel: + if current_context(): err = 
( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % (f, LoggingContext.current_context(), start_context) + ) % (f, current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) def check_ctx(r): - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]): function """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context @functools.wraps(f) def check_yield_points_inner(*args, **kwargs): @@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]): last_yield_line_no = gen.gi_frame.f_lineno result = None # type: Any while True: - expected_context = LoggingContext.current_context() + expected_context = current_context() try: isFailure = isinstance(result, Failure) @@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]): else: d = gen.send(result) except (StopIteration, defer._DefGen_Return) as e: - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens when the context is lost sometime *after* the # final yield and returning. E.g. we forgot to yield on a # function that returns a deferred. @@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( f.__qualname__, expected_context, - LoggingContext.current_context(), + current_context(), f.__code__.co_filename, last_yield_line_no, ) @@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]): # This happens if we yield on a deferred that doesn't follow # the log context rules without wrapping in a `make_deferred_yieldable`. # We raise here as this should never happen. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): err = ( "%s yielded with context %s rather than sentinel," " yielded on line %d in %s" % ( frame.f_code.co_name, - LoggingContext.current_context(), + current_context(), frame.f_lineno, frame.f_code.co_filename, ) @@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]): except Exception as e: result = Failure(e) - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. 
the @@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( frame.f_code.co_name, expected_context, - LoggingContext.current_context(), + current_context(), last_yield_line_no, frame.f_lineno, frame.f_code.co_filename, -- cgit 1.5.1 From 7966a1cde9d4b598faa06620424844f2b35c94af Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 30 Mar 2020 19:06:52 +0100 Subject: Rewrite prune_old_outbound_device_pokes for efficiency (#7159) make sure we clear out all but one update for the user --- changelog.d/7159.bugfix | 1 + synapse/handlers/federation.py | 25 +------- synapse/storage/data_stores/main/devices.py | 71 ++++++++++++++++++---- synapse/util/stringutils.py | 21 ++++++- tests/federation/test_federation_sender.py | 92 +++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7159.bugfix (limited to 'synapse/util') diff --git a/changelog.d/7159.bugfix b/changelog.d/7159.bugfix new file mode 100644 index 0000000000..1b341b127b --- /dev/null +++ b/changelog.d/7159.bugfix @@ -0,0 +1 @@ +Fix excessive CPU usage by `prune_old_outbound_device_pokes` job. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38ab6a8fc3..c7aa7acf3b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server -from ._base import BaseHandler - logger = logging.getLogger(__name__) @@ -93,27 +93,6 @@ class _NewEventInfo: auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) -def shortstr(iterable, maxitems=5): - """If iterable has maxitems or fewer, return the stringification of a list - containing those items. - - Otherwise, return the stringification of a a list with the first maxitems items, - followed by "...". - - Args: - iterable (Iterable): iterable to truncate - maxitems (int): number of items to return before truncating - - Returns: - unicode - """ - - items = list(itertools.islice(iterable, maxitems + 1)) - if len(items) <= maxitems: - return str(items) - return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" - - class FederationHandler(BaseHandler): """Handles events that originated from federation. 
Responsible for: diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 2d47cfd131..3140e1b722 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -41,6 +41,7 @@ from synapse.util.caches.descriptors import ( cachedList, ) from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr logger = logging.getLogger(__name__) @@ -1092,18 +1093,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self): + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + yesterday = self._clock.time_msec() - prune_age def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 GROUP BY destination, user_id HAVING min(ts) < ? AND count(*) > 1 """ @@ -1114,14 +1144,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not rows: return + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. delete_sql = """ DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) 
+ ) """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount # Since we've deleted unsent deltas, we need to remove the entry # of last successful sent so that the prev_ids are correctly set. @@ -1131,7 +1176,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """ txn.executemany(sql, ((row[0], row[1]) for row in rows)) - logger.info("Pruned %d device list outbound pokes", txn.rowcount) + logger.info("Pruned %d device list outbound pokes", count) return run_as_background_process( "prune_old_outbound_device_pokes", diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2c0dcb5208..6899bcb788 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -13,10 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import itertools import random import re import string +from collections import Iterable import six from six import PY2, PY3 @@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) + + +def shortstr(iterable: Iterable, maxitems: int = 5) -> str: + """If iterable has maxitems or fewer, return the stringification of a list + containing those items. + + Otherwise, return the stringification of a a list with the first maxitems items, + followed by "...". + + Args: + iterable: iterable to truncate + maxitems: number of items to return before truncating + """ + + items = list(itertools.islice(iterable, maxitems + 1)) + if len(items) <= maxitems: + return str(items) + return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 7763b12159..a5fe5c6880 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -370,6 +370,98 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): devices = {edu["content"]["device_id"] for edu in self.edus} self.assertEqual({"D1", "D2", "D3"}, devices) + def test_prune_outbound_device_pokes1(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server has never been reachable. + """ + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + # create devices + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 4) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # there should be a single update for this user. 
+ self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + + def test_prune_outbound_device_pokes2(self): + """If a destination is unreachable, and the updates are pruned, we should get + a single update. + + This case tests the behaviour when the server was reachable, but then goes + offline. + """ + + # create first device + u1 = self.register_user("user", "pass") + self.login("user", "pass", device_id="D1") + + # expect the update EDU + self.assertEqual(len(self.edus), 1) + self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) + + # now the server goes offline + mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn.side_effect = lambda t, cb: defer.fail("fail") + + self.login("user", "pass", device_id="D2") + self.login("user", "pass", device_id="D3") + + # delete them again + self.get_success( + self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) + ) + + self.assertGreaterEqual(mock_send_txn.call_count, 3) + + # run the prune job + self.reactor.advance(10) + self.get_success( + self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1) + ) + + # recover the server + mock_send_txn.side_effect = self.record_transaction + self.hs.get_federation_sender().send_device_messages("host2") + self.pump() + + # ... and we should get a single update for this user. + self.assertEqual(len(self.edus), 1) + edu = self.edus.pop(0) + self.assertEqual(edu["edu_type"], "m.device_list_update") + c = edu["content"] + + # synapse uses an empty prev_id list to indicate "needs a full resync". + self.assertEqual(c["prev_id"], []) + def check_device_update_edu( self, edu: JsonDict, -- cgit 1.5.1 From 0f8f02bc39cf1879780f82bdb8dd581588d14cca Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 20 Apr 2020 11:43:29 +0100 Subject: On catchup, process each row with its own stream id (#7286) Other parts of the code (such as the StreamChangeCache) assume that there will not be multiple changes with the same stream id. This code was introduced in #7024, and I hope this fixes #7206. --- changelog.d/7286.misc | 1 + synapse/replication/tcp/handler.py | 73 ++++++++++++++++++++++++++++-- synapse/util/caches/stream_change_cache.py | 3 ++ 3 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 changelog.d/7286.misc (limited to 'synapse/util') diff --git a/changelog.d/7286.misc b/changelog.d/7286.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7286.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2f5a299141..e32e68e8c4 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,7 +15,18 @@ # limitations under the License. 
import logging -from typing import Any, Callable, Dict, List, Optional, Set +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, +) from prometheus_client import Counter @@ -268,11 +279,14 @@ class ReplicationCommandHandler: missing_updates, ) = await stream.get_updates_since(current_token, cmd.token) - if updates: + # TODO: add some tests for this + + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], + cmd.stream_name, token, [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. @@ -404,3 +418,52 @@ class ReplicationCommandHandler: We need to check if the client is interested in the stream or not """ self.send_command(RdataCommand(stream_name, token, data)) + + +UpdateToken = TypeVar("UpdateToken") +UpdateRow = TypeVar("UpdateRow") + + +def _batch_updates( + updates: Iterable[Tuple[UpdateToken, UpdateRow]] +) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]: + """Collect stream updates with the same token together + + Given a series of updates returned by Stream.get_updates_since(), collects + the updates which share the same stream_id together. + + For example: + + [(1, a), (1, b), (2, c), (3, d), (3, e)] + + becomes: + + [ + (1, [a, b]), + (2, [c]), + (3, [d, e]), + ] + """ + + update_iter = iter(updates) + + first_update = next(update_iter, None) + if first_update is None: + # empty input + return + + current_batch_token = first_update[0] + current_batch = [first_update[1]] + + for token, row in update_iter: + if token != current_batch_token: + # different token to the previous row: flush the previous + # batch and start anew + yield current_batch_token, current_batch + current_batch_token = token + current_batch = [] + + current_batch.append(row) + + # flush the final batch + yield current_batch_token, current_batch diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 235f64049c..c61d36a82e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -126,6 +126,9 @@ class StreamChangeCache(object): """ assert type(stream_pos) is int + # FIXME: add a sanity check here that we are not overwriting existing + # data in self._cache + if stream_pos > self._earliest_known_stream_pos: old_pos = self._entity_to_key.get(entity, None) if old_pos is not None: -- cgit 1.5.1 From 13683a3a223a3f08b295973e7b8aa6e9299a2fa5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 13:45:40 +0100 Subject: Extend StreamChangeCache to support multiple entities per stream ID (#7303) First some background: StreamChangeCache is used to keep track of what "entities" have changed since a given stream ID. So for example, we might use it to keep track of when the last to-device message for a given user was received [1], and hence whether we need to pull any to-device messages from the database on a sync [2]. Now, it turns out that StreamChangeCache didn't support more than one thing being changed at a given stream_id (this was part of the problem with #7206). However, it's entirely valid to send to-device messages to more than one user at a time. 
As it turns out, this did in fact work, because *some* methods of StreamChangeCache coped ok with having multiple things changing on the same stream ID, and it seems we never actually use the methods which don't work on the stream change caches where we allow multiple changes at the same stream ID. But that feels horribly fragile, hence: let's update StreamChangeCache to properly support this, and add some typing and some more tests while we're at it. [1]: https://github.com/matrix-org/synapse/blob/release-v1.12.3/synapse/storage/data_stores/main/deviceinbox.py#L301 [2]: https://github.com/matrix-org/synapse/blob/release-v1.12.3/synapse/storage/data_stores/main/deviceinbox.py#L47-L51 --- changelog.d/7303.misc | 1 + stubs/sortedcontainers/__init__.pyi | 13 +++ stubs/sortedcontainers/sorteddict.pyi | 124 +++++++++++++++++++++++++++++ synapse/util/caches/stream_change_cache.py | 117 ++++++++++++++++----------- tests/util/test_stream_change_cache.py | 69 +++++++++++++--- tox.ini | 4 +- 6 files changed, 272 insertions(+), 56 deletions(-) create mode 100644 changelog.d/7303.misc create mode 100644 stubs/sortedcontainers/__init__.pyi create mode 100644 stubs/sortedcontainers/sorteddict.pyi (limited to 'synapse/util') diff --git a/changelog.d/7303.misc b/changelog.d/7303.misc new file mode 100644 index 0000000000..aa89c2b254 --- /dev/null +++ b/changelog.d/7303.misc @@ -0,0 +1 @@ +Fix StreamChangeCache to work with multiple entities changing on the same stream id. diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi new file mode 100644 index 0000000000..073b806d3c --- /dev/null +++ b/stubs/sortedcontainers/__init__.pyi @@ -0,0 +1,13 @@ +from .sorteddict import ( + SortedDict, + SortedKeysView, + SortedItemsView, + SortedValuesView, +) + +__all__ = [ + "SortedDict", + "SortedKeysView", + "SortedItemsView", + "SortedValuesView", +] diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi new file mode 100644 index 0000000000..68779f968e --- /dev/null +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -0,0 +1,124 @@ +# stub for SortedDict. This is a lightly edited copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/eea42df1f7bad2792e8da77335ff888f04b9e5ae/sortedcontainers/sorteddict.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterator, + Iterable, + ItemsView, + KeysView, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Tuple, + Union, + ValuesView, + overload, +) + +_T = TypeVar("_T") +_S = TypeVar("_S") +_T_h = TypeVar("_T_h", bound=Hashable) +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. +_KT_co = TypeVar("_KT_co", covariant=True, bound=Hashable) +_VT_co = TypeVar("_VT_co", covariant=True) +_SD = TypeVar("_SD", bound=SortedDict) +_Key = Callable[[_T], Any] + +class SortedDict(Dict[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + @overload + def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __map: Mapping[_KT, _VT], **kwargs: _VT + ) -> None: ... + @overload + def __init__( + self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... 
+ @property + def key(self) -> Optional[_Key[_KT]]: ... + @property + def iloc(self) -> SortedKeysView[_KT]: ... + def clear(self) -> None: ... + def __delitem__(self, key: _KT) -> None: ... + def __iter__(self) -> Iterator[_KT]: ... + def __reversed__(self) -> Iterator[_KT]: ... + def __setitem__(self, key: _KT, value: _VT) -> None: ... + def _setitem(self, key: _KT, value: _VT) -> None: ... + def copy(self: _SD) -> _SD: ... + def __copy__(self: _SD) -> _SD: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + @classmethod + @overload + def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... + def keys(self) -> SortedKeysView[_KT]: ... + def items(self) -> SortedItemsView[_KT, _VT]: ... + def values(self) -> SortedValuesView[_VT]: ... + @overload + def pop(self, key: _KT) -> _VT: ... + @overload + def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ... + def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ... + def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ... + @overload + def update(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def update(self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT) -> None: ... + @overload + def update(self, **kwargs: _VT) -> None: ... + def __reduce__( + self, + ) -> Tuple[ + Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], + ]: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_KT]: ... + def bisect_left(self, value: _KT) -> int: ... + def bisect_right(self, value: _KT) -> int: ... + +class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]): + @overload + def __getitem__(self, index: int) -> _KT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_KT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedItemsView( # type: ignore + ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]] +): + def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ... + @overload + def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ... + @overload + def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + +class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]): + @overload + def __getitem__(self, index: int) -> _VT_co: ... + @overload + def __getitem__(self, index: slice) -> List[_VT_co]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c61d36a82e..38dc3f501e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Iterable, List, Mapping, Optional, Set from six import integer_types @@ -23,8 +24,11 @@ from synapse.util import caches logger = logging.getLogger(__name__) +# for now, assume all entities in the cache are strings +EntityType = str -class StreamChangeCache(object): + +class StreamChangeCache: """Keeps track of the stream positions of the latest change in a set of entities. Typically the entity will be a room or user id. 
@@ -34,10 +38,23 @@ class StreamChangeCache(object): old then the cache will simply return all given entities. """ - def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None): + def __init__( + self, + name: str, + current_stream_pos: int, + max_size=10000, + prefilled_cache: Optional[Mapping[EntityType, int]] = None, + ): self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) - self._entity_to_key = {} - self._cache = SortedDict() + self._entity_to_key = {} # type: Dict[EntityType, int] + + # map from stream id to the a set of entities which changed at that stream id. + self._cache = SortedDict() # type: SortedDict[int, Set[EntityType]] + + # the earliest stream_pos for which we can reliably answer + # get_all_entities_changed. In other words, one less than the earliest + # stream_pos for which we know _cache is valid. + # self._earliest_known_stream_pos = current_stream_pos self.name = name self.metrics = caches.register_cache("cache", self.name, self._cache) @@ -46,7 +63,7 @@ class StreamChangeCache(object): for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) - def has_entity_changed(self, entity, stream_pos): + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ assert type(stream_pos) in integer_types @@ -67,22 +84,17 @@ class StreamChangeCache(object): self.metrics.inc_hits() return False - def get_entities_changed(self, entities, stream_pos): + def get_entities_changed( + self, entities: Iterable[EntityType], stream_pos: int + ) -> Set[EntityType]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the position is too old it will just return the given list. """ - assert type(stream_pos) is int - - if stream_pos >= self._earliest_known_stream_pos: - changed_entities = { - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - } - - result = changed_entities.intersection(entities) - + changed_entities = self.get_all_entities_changed(stream_pos) + if changed_entities is not None: + result = set(changed_entities).intersection(entities) self.metrics.inc_hits() else: result = set(entities) @@ -90,13 +102,13 @@ class StreamChangeCache(object): return result - def has_any_entity_changed(self, stream_pos): + def has_any_entity_changed(self, stream_pos: int) -> bool: """Returns if any entity has changed """ assert type(stream_pos) is int if not self._cache: - # If we have no cache, nothing can have changed. + # If the cache is empty, nothing can have changed. return False if stream_pos >= self._earliest_known_stream_pos: @@ -106,45 +118,58 @@ class StreamChangeCache(object): self.metrics.inc_misses() return True - def get_all_entities_changed(self, stream_pos): - """Returns all entites that have had new things since the given + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + """Returns all entities that have had new things since the given position. If the position is too old it will return None. + + Returns the entities in the order that they were changed. 
""" assert type(stream_pos) is int - if stream_pos >= self._earliest_known_stream_pos: - return [ - self._cache[k] - for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)) - ] - else: + if stream_pos < self._earliest_known_stream_pos: return None - def entity_has_changed(self, entity, stream_pos): + changed_entities = [] # type: List[EntityType] + + for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): + changed_entities.extend(self._cache[k]) + return changed_entities + + def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """Informs the cache that the entity has been changed at the given position. """ assert type(stream_pos) is int - # FIXME: add a sanity check here that we are not overwriting existing - # data in self._cache - - if stream_pos > self._earliest_known_stream_pos: - old_pos = self._entity_to_key.get(entity, None) - if old_pos is not None: - stream_pos = max(stream_pos, old_pos) - self._cache.pop(old_pos, None) - self._cache[stream_pos] = entity - self._entity_to_key[entity] = stream_pos - - while len(self._cache) > self._max_size: - k, r = self._cache.popitem(0) - self._earliest_known_stream_pos = max( - k, self._earliest_known_stream_pos - ) - self._entity_to_key.pop(r, None) - - def get_max_pos_of_last_change(self, entity): + if stream_pos <= self._earliest_known_stream_pos: + return + + old_pos = self._entity_to_key.get(entity, None) + if old_pos is not None: + if old_pos >= stream_pos: + # nothing to do + return + e = self._cache[old_pos] + e.remove(entity) + if not e: + # cache at this point is now empty + del self._cache[old_pos] + + e1 = self._cache.get(stream_pos) + if e1 is None: + e1 = self._cache[stream_pos] = set() + e1.add(entity) + self._entity_to_key[entity] = stream_pos + + # if the cache is too big, remove entries + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + del self._entity_to_key[entity] + + def get_max_pos_of_last_change(self, entity: EntityType) -> int: + """Returns an upper bound of the stream id of the last change to an entity. """ diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 72a9de5370..6857933540 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -28,18 +28,26 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 6) cache.entity_has_changed("bar@baz.net", 7) + # also test multiple things changing on the same stream ID + cache.entity_has_changed("user2@foo.com", 8) + cache.entity_has_changed("bar2@baz.net", 8) + # If it's been changed after that stream position, return True self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("bar2@baz.net", 4)) + self.assertTrue(cache.has_entity_changed("user2@foo.com", 4)) # If it's been changed at that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 8)) # If there's no changes after that stream position, return False self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + self.assertFalse(cache.has_entity_changed("user2@foo.com", 9)) # If the entity does not exist, return False. 
- self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + self.assertFalse(cache.has_entity_changed("not@here.website", 9)) # If we request before the stream cache's earliest known position, # return True, whether it's a known entity or not. @@ -47,7 +55,7 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("not@here.website", 0)) @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) - def test_has_entity_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -64,11 +72,20 @@ class StreamChangeCacheTests(unittest.TestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) + # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) + self.assertEqual( + cache.get_all_entities_changed(2), ["user@elsewhere.org", "bar@baz.net"], + ) + self.assertIsNone(cache.get_all_entities_changed(1)) def test_get_all_entities_changed(self): """ @@ -80,18 +97,52 @@ class StreamChangeCacheTests(unittest.TestCase): cache.entity_has_changed("user@foo.com", 2) cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - self.assertEqual( - cache.get_all_entities_changed(1), - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], - ) - self.assertEqual( - cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] - ) + r = cache.get_all_entities_changed(1) + + # either of these are valid + ok1 = [ + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + "user@elsewhere.org", + ] + ok2 = [ + "user@foo.com", + "anotheruser@foo.com", + "bar@baz.net", + "user@elsewhere.org", + ] + self.assertTrue(r == ok1 or r == ok2) + + r = cache.get_all_entities_changed(2) + self.assertTrue(r == ok1[1:] or r == ok2[1:]) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) self.assertEqual(cache.get_all_entities_changed(0), None) + # ... 
later, things gest more updates + cache.entity_has_changed("user@foo.com", 5) + cache.entity_has_changed("bar@baz.net", 5) + cache.entity_has_changed("anotheruser@foo.com", 6) + + ok1 = [ + "user@elsewhere.org", + "user@foo.com", + "bar@baz.net", + "anotheruser@foo.com", + ] + ok2 = [ + "user@elsewhere.org", + "bar@baz.net", + "user@foo.com", + "anotheruser@foo.com", + ] + r = cache.get_all_entities_changed(3) + self.assertTrue(r == ok1 or r == ok2) + def test_has_any_entity_changed(self): """ StreamChangeCache.has_any_entity_changed will return True if any diff --git a/tox.ini b/tox.ini index 42b2d74891..31011d7436 100644 --- a/tox.ini +++ b/tox.ini @@ -202,7 +202,9 @@ commands = mypy \ synapse/spam_checker_api \ synapse/storage/engines \ synapse/storage/database.py \ - synapse/streams + synapse/streams \ + synapse/util/caches/stream_change_cache.py \ + tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: # -- cgit 1.5.1 From f9073893af82eec64b594dbcaef37c407a291c52 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 5 May 2020 17:07:59 +0100 Subject: Speed up fetching device lists changes in sync. Currently we copy `users_who_share_room` needlessly about three times, which is expensive when the set is large (which it can easily be). --- synapse/handlers/sync.py | 12 ++++++++---- synapse/storage/data_stores/main/devices.py | 4 ++-- synapse/util/caches/stream_change_cache.py | 19 +++++++++++++++---- 3 files changed, 25 insertions(+), 10 deletions(-) (limited to 'synapse/util') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4f76b7a743..00718d7f2d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1143,10 +1143,14 @@ class SyncHandler(object): user_id ) - tracked_users = set(users_who_share_room) - - # Always tell the user about their own devices - tracked_users.add(user_id) + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) + + tracked_users = users_who_share_room # Step 1a, check for changes in devices of users we share a room with users_that_have_changed = await self.store.get_users_whose_devices_changed( diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index ee3a2ab031..03f5141e6c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -541,8 +541,8 @@ class DeviceWorkerStore(SQLBaseStore): # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - to_check = list( - self._device_list_stream_cache.get_entities_changed(user_ids, from_key) + to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key ) if not to_check: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 38dc3f501e..e54f80d76e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,12 +14,13 @@ # limitations under the License. 
import logging -from typing import Dict, Iterable, List, Mapping, Optional, Set +from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types from sortedcontainers import SortedDict +from synapse.types import Collection from synapse.util import caches logger = logging.getLogger(__name__) @@ -85,8 +86,8 @@ class StreamChangeCache: return False def get_entities_changed( - self, entities: Iterable[EntityType], stream_pos: int - ) -> Set[EntityType]: + self, entities: Collection[EntityType], stream_pos: int + ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ Returns subset of entities that have had new things since the given position. Entities unknown to the cache will be returned. If the @@ -94,7 +95,17 @@ class StreamChangeCache: """ changed_entities = self.get_all_entities_changed(stream_pos) if changed_entities is not None: - result = set(changed_entities).intersection(entities) + # We now do an intersection, trying to do so in the most efficient + # way possible (some of these sets are *large*). First check in the + # given iterable is already set that we can reuse, otherwise we + # create a set of the *smallest* of the two iterables and call + # `intersection(..)` on it (this can be twice as fast as the reverse). + if isinstance(entities, (set, frozenset)): + result = entities.intersection(changed_entities) + elif len(changed_entities) < len(entities): + result = set(changed_entities).intersection(entities) + else: + result = set(entities).intersection(changed_entities) self.metrics.inc_hits() else: result = set(entities) -- cgit 1.5.1 From 7cb8b4bc67042a39bd1b0e05df46089a2fce1955 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 12 May 2020 03:45:23 +1000 Subject: Allow configuration of Synapse's cache without using synctl or environment variables (#6391) --- changelog.d/6391.feature | 1 + docs/sample_config.yaml | 43 +++++- synapse/api/auth.py | 4 +- synapse/app/homeserver.py | 5 +- synapse/config/cache.py | 164 ++++++++++++++++++++++ synapse/config/database.py | 6 - synapse/config/homeserver.py | 2 + synapse/http/client.py | 6 +- synapse/metrics/_exposition.py | 12 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_rule_evaluator.py | 4 +- synapse/replication/slave/storage/client_ips.py | 3 +- synapse/state/__init__.py | 4 +- synapse/storage/data_stores/main/client_ips.py | 3 +- synapse/storage/data_stores/main/events_worker.py | 5 +- synapse/storage/data_stores/state/store.py | 6 +- synapse/util/caches/__init__.py | 144 ++++++++++--------- synapse/util/caches/descriptors.py | 36 ++++- synapse/util/caches/expiringcache.py | 29 +++- synapse/util/caches/lrucache.py | 52 +++++-- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/stream_change_cache.py | 33 ++++- synapse/util/caches/ttlcache.py | 2 +- tests/config/test_cache.py | 127 +++++++++++++++++ tests/storage/test__base.py | 8 +- tests/storage/test_appservice.py | 10 +- tests/storage/test_base.py | 3 +- tests/test_metrics.py | 34 +++++ tests/util/test_expiring_cache.py | 2 +- tests/util/test_lrucache.py | 6 +- tests/util/test_stream_change_cache.py | 5 +- tests/utils.py | 1 + 32 files changed, 620 insertions(+), 146 deletions(-) create mode 100644 changelog.d/6391.feature create mode 100644 synapse/config/cache.py create mode 100644 tests/config/test_cache.py (limited to 'synapse/util') diff --git a/changelog.d/6391.feature b/changelog.d/6391.feature new file mode 100644 index 0000000000..f123426e23 --- /dev/null +++ b/changelog.d/6391.feature @@ -0,0 
+1 @@ +Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 5abeaf519b..8a8415b9a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -603,6 +603,45 @@ acme: +## Caching ## + +# Caching can be configured through the following options. +# +# A cache 'factor' is a multiplier that can be applied to each of +# Synapse's caches in order to increase or decrease the maximum +# number of entries that can be stored. + +# The number of events to cache in memory. Not affected by +# caches.global_factor. +# +#event_cache_size: 10K + +caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + ## Database ## # The 'database' setting defines the database that synapse uses to store all of @@ -646,10 +685,6 @@ database: args: database: DATADIR/homeserver.db -# Number of events to cache in memory. 
-# -#event_cache_size: 10K - ## Logging ## diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1ad5ff9410..e009b1a760 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap, UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -73,7 +73,7 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + self.token_cache = LruCache(10000) register_cache("cache", "token_cache", self.token_cache) self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bc8695d8dd..d7f337e586 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module @@ -516,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size # # Performance statistics diff --git a/synapse/config/cache.py b/synapse/config/cache.py new file mode 100644 index 0000000000..91036a012e --- /dev/null +++ b/synapse/config/cache.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from typing import Callable, Dict + +from ._base import Config, ConfigError + +# The prefix for all cache factor-related environment variables +_CACHES = {} +_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" +_DEFAULT_FACTOR_SIZE = 0.5 +_DEFAULT_EVENT_CACHE_SIZE = "10K" + + +class CacheProperties(object): + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None + + +properties = CacheProperties() + + +def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): + """Register a cache that's size can dynamically change + + Args: + cache_name: A reference to the cache + cache_resize_callback: A callback function that will be ran whenever + the cache needs to be resized + """ + _CACHES[cache_name.lower()] = cache_resize_callback + + # Ensure all loaded caches are sized appropriately + # + # This method should only run once the config has been read, + # as it uses values read from it + if properties.resize_all_caches_func: + properties.resize_all_caches_func() + + +class CacheConfig(Config): + section = "caches" + _environ = os.environ + + @staticmethod + def reset(): + """Resets the caches to their defaults. Used for tests.""" + properties.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + properties.resize_all_caches_func = None + _CACHES.clear() + + def generate_config_section(self, **kwargs): + return """\ + ## Caching ## + + # Caching can be configured through the following options. + # + # A cache 'factor' is a multiplier that can be applied to each of + # Synapse's caches in order to increase or decrease the maximum + # number of entries that can be stored. + + # The number of events to cache in memory. Not affected by + # caches.global_factor. + # + #event_cache_size: 10K + + caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. 
SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + """ + + def read_config(self, config, **kwargs): + self.event_cache_size = self.parse_size( + config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) + ) + self.cache_factors = {} # type: Dict[str, float] + + cache_config = config.get("caches") or {} + self.global_factor = cache_config.get( + "global_factor", properties.default_factor_size + ) + if not isinstance(self.global_factor, (int, float)): + raise ConfigError("caches.global_factor must be a number.") + + # Set the global one so that it's reflected in new caches + properties.default_factor_size = self.global_factor + + # Load cache factors from the config + individual_factors = cache_config.get("per_cache_factors") or {} + if not isinstance(individual_factors, dict): + raise ConfigError("caches.per_cache_factors must be a dictionary") + + # Override factors from environment if necessary + individual_factors.update( + { + key[len(_CACHE_PREFIX) + 1 :].lower(): float(val) + for key, val in self._environ.items() + if key.startswith(_CACHE_PREFIX + "_") + } + ) + + for cache, factor in individual_factors.items(): + if not isinstance(factor, (int, float)): + raise ConfigError( + "caches.per_cache_factors.%s must be a number" % (cache.lower(),) + ) + self.cache_factors[cache.lower()] = factor + + # Resize all caches (if necessary) with the new factors we've loaded + self.resize_all_caches() + + # Store this function so that it can be called from other classes without + # needing an instance of Config + properties.resize_all_caches_func = self.resize_all_caches + + def resize_all_caches(self): + """Ensure all cache sizes are up to date + + For each cache, run the mapped callback function with either + a specific cache factor or the default, global one. + """ + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/database.py b/synapse/config/database.py index 5b662d1b01..1064c2697b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -68,10 +68,6 @@ database: name: sqlite3 args: database: %(database_path)s - -# Number of events to cache in memory. -# -#event_cache_size: 10K """ @@ -116,8 +112,6 @@ class DatabaseConfig(Config): self.databases = [] def read_config(self, config, **kwargs): - self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - # We *experimentally* support specifying multiple databases via the # `databases` key. 
This is a map from a label to database config in the # same format as the `database` config option, plus an extra diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 996d3e6bf7..2c7b3a699f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig from .consent_config import ConsentConfig @@ -55,6 +56,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + CacheConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, diff --git a/synapse/http/client.py b/synapse/http/client.py index 58eb47c69c..3cef747a4d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred -from synapse.util.caches import CACHE_SIZE_FACTOR logger = logging.getLogger(__name__) @@ -241,7 +240,10 @@ class SimpleHttpClient(object): # tends to do so in batches, so we need to allow the pool to keep # lots of idle connections around. pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + # XXX: The justification for using the cache factor here is that larger instances + # will need both more cache and more connections. + # Still, this should probably be a separate dial + pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 self.agent = ProxyAgent( diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index a248103191..ab7f948ed4 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -33,6 +33,8 @@ from prometheus_client import REGISTRY from twisted.web.resource import Resource +from synapse.util import caches + try: from prometheus_client.samples import Sample except ImportError: @@ -103,13 +105,15 @@ def nameify_sample(sample): def generate_latest(registry, emit_help=False): - output = [] - for metric in registry.collect(): + # Trigger the cache metrics to be rescraped, which updates the common + # metrics but do not produce metrics themselves + for collector in caches.collectors_by_name.values(): + collector.collect() - if metric.name.startswith("__unused"): - continue + output = [] + for metric in registry.collect(): if not metric.samples: # No samples, don't bother. 
continue diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 433ca2f416..e75d964ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache( "cache", "push_rules_delta_state_cache_metric", cache=[], # Meaningless size, as this isn't a cache that stores values + resizable=False, ) @@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", - cache=[], # Meaningless size, as this isn't a cache that stores values + cache=[], # Meaningless size, as this isn't a cache that stores values, + resizable=False, ) @defer.inlineCallbacks diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4cd702b5fa..11032491af 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -22,7 +22,7 @@ from six import string_types from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object): # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +regex_cache = LruCache(50000) register_cache("cache", "regex_push_cache", regex_cache) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index fbf996e33a..1a38f53dfb 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,6 @@ from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.database import Database -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore @@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4afefc6b1d..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -35,7 +35,6 @@ from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -53,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -447,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, diff --git 
a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 92bc06919b..71f8d43a76 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database, make_tuple_comparison_clause -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) super(ClientIpStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 73df6b33ba..b8c1bbdf99 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -75,7 +75,10 @@ class EventsWorkerStore(SQLBaseStore): super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 57a5267663..f3ad1e4369 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.types import StateMap -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), + 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), + "*stateGroupMembersCache*", 500000, ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index da5077b471..4b8a0c7a8f 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,27 +15,17 @@ # limitations under the License. 
import logging -import os -from typing import Dict +from typing import Callable, Dict, Optional import six from six.moves import intern -from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily - -logger = logging.getLogger(__name__) - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) +import attr +from prometheus_client.core import Gauge +from synapse.config.cache import add_resizable_cache -def get_cache_factor_for(cache_name): - env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() - factor = os.environ.get(env_var) - if factor: - return float(factor) - - return CACHE_SIZE_FACTOR - +logger = logging.getLogger(__name__) caches_by_name = {} collectors_by_name = {} # type: Dict @@ -44,6 +34,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -53,67 +44,82 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache, collect_callback=None): - """Register a cache object for metric collection. +@attr.s +class CacheMetric(object): + + _cache = attr.ib() + _cache_type = attr.ib(type=str) + _cache_name = attr.ib(type=str) + _collect_callback = attr.ib(type=Optional[Callable]) + + hits = attr.ib(default=0) + misses = attr.ib(default=0) + evicted_size = attr.ib(default=0) + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + try: + if self._cache_type == "response_cache": + response_cache_size.labels(self._cache_name).set(len(self._cache)) + response_cache_hits.labels(self._cache_name).set(self.hits) + response_cache_evicted.labels(self._cache_name).set(self.evicted_size) + response_cache_total.labels(self._cache_name).set( + self.hits + self.misses + ) + else: + cache_size.labels(self._cache_name).set(len(self._cache)) + cache_hits.labels(self._cache_name).set(self.hits) + cache_evicted.labels(self._cache_name).set(self.evicted_size) + cache_total.labels(self._cache_name).set(self.hits + self.misses) + if getattr(self._cache, "max_size", None): + cache_max_size.labels(self._cache_name).set(self._cache.max_size) + if self._collect_callback: + self._collect_callback() + except Exception as e: + logger.warning("Error calculating metrics for %s: %s", self._cache_name, e) + raise + + +def register_cache( + cache_type: str, + cache_name: str, + cache, + collect_callback: Optional[Callable] = None, + resizable: bool = True, + resize_callback: Optional[Callable] = None, +) -> CacheMetric: + """Register a cache object for metric collection and resizing. Args: - cache_type (str): - cache_name (str): name of the cache - cache (object): cache itself - collect_callback (callable|None): if not None, a function which is called during - metric collection to update additional metrics. 
+ cache_type + cache_name: name of the cache + cache: cache itself + collect_callback: If given, a function which is called during metric + collection to update additional metrics. + resizable: Whether this cache supports being resized. + resize_callback: A function which can be called to resize the cache. Returns: CacheMetric: an object which provides inc_{hits,misses,evictions} methods """ + if resizable: + if not resize_callback: + resize_callback = getattr(cache, "set_cache_factor") + add_resizable_cache(cache_name, resize_callback) - # Check if the metric is already registered. Unregister it, if so. - # This usually happens during tests, as at runtime these caches are - # effectively singletons. + metric = CacheMetric(cache, cache_type, cache_name, collect_callback) metric_name = "cache_%s_%s" % (cache_type, cache_name) - if metric_name in collectors_by_name.keys(): - REGISTRY.unregister(collectors_by_name[metric_name]) - - class CacheMetric(object): - - hits = 0 - misses = 0 - evicted_size = 0 - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def inc_evictions(self, size=1): - self.evicted_size += size - - def describe(self): - return [] - - def collect(self): - try: - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) - if collect_callback: - collect_callback() - except Exception as e: - logger.warning("Error calculating metrics for %s: %s", cache_name, e) - raise - - yield GaugeMetricFamily("__unused", "") - - metric = CacheMetric() - REGISTRY.register(metric) caches_by_name[cache_name] = cache collectors_by_name[metric_name] = metric return metric diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2e8f6543e5..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import functools import inspect import logging @@ -30,7 +31,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -81,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -89,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -99,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -111,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cddf1ed515..2726b67b6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -18,6 +18,7 @@ from collections import OrderedDict from six import iteritems, itervalues +from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -51,15 +52,16 @@ class ExpiringCache(object): an item on access. Defaults to False. iterable (bool): If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. 
- """ self._cache_name = cache_name + self._original_max_size = max_len + + self._max_size = int(max_len * cache_config.properties.default_factor_size) + self._clock = clock - self._max_len = max_len self._expiry_ms = expiry_ms - self._reset_expiry_on_get = reset_expiry_on_get self._cache = OrderedDict() @@ -82,9 +84,11 @@ class ExpiringCache(object): def __setitem__(self, key, value): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + self.evict() + def evict(self): # Evict if there are now too many items - while self._max_len and len(self) > self._max_len: + while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: self.metrics.inc_evictions(len(value.value)) @@ -170,6 +174,23 @@ class ExpiringCache(object): else: return len(self._cache) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self._max_size: + self._max_size = new_size + self.evict() + return True + return False + class _CacheEntry(object): __slots__ = ["time", "value"] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1536cb64f3..29fabac3cd 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading from functools import wraps +from typing import Callable, Optional, Type, Union +from synapse.config import cache as cache_config from synapse.util.caches.treecache import TreeCache @@ -52,17 +53,18 @@ class LruCache(object): def __init__( self, - max_size, - keylen=1, - cache_type=dict, - size_callback=None, - evicted_callback=None, + max_size: int, + keylen: int = 1, + cache_type: Type[Union[dict, TreeCache]] = dict, + size_callback: Optional[Callable] = None, + evicted_callback: Optional[Callable] = None, + apply_cache_factor_from_config: bool = True, ): """ Args: - max_size (int): + max_size: The maximum amount of entries the cache can hold - keylen (int): + keylen: The length of the tuple used as the cache key cache_type (type): type of underlying cache to be used. Typically one of dict @@ -73,9 +75,23 @@ class LruCache(object): evicted_callback (func(int)|None): if not None, called on eviction with the size of the evicted entry + + apply_cache_factor_from_config (bool): If true, `max_size` will be + multiplied by a cache factor derived from the homeserver config """ cache = cache_type() self.cache = cache # Used for introspection. + + # Save the original max size, and apply the default size factor. + self._original_max_size = max_size + # We previously didn't apply the cache factor here, and as such some caches were + # not affected by the global cache factor. 
Add an option here to disable applying + # the cache factor when a cache is created + if apply_cache_factor_from_config: + self.max_size = int(max_size * cache_config.properties.default_factor_size) + else: + self.max_size = int(max_size) + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -83,7 +99,7 @@ class LruCache(object): lock = threading.Lock() def evict(): - while cache_len() > max_size: + while cache_len() > self.max_size: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) @@ -236,6 +252,7 @@ class LruCache(object): return key in cache self.sentinel = object() + self._on_resize = evict self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -266,3 +283,20 @@ class LruCache(object): def __contains__(self, key): return self.contains(key) + + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self.max_size: + self.max_size = new_size + self._on_resize() + return True + return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index b68f9fe0d4..a6c60888e5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,7 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache("response_cache", name, self) + self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self): return len(self.pending_result_cache) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e54f80d76e..2a161bf244 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types @@ -46,7 +47,8 @@ class StreamChangeCache: max_size=10000, prefilled_cache: Optional[Mapping[EntityType, int]] = None, ): - self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) + self._original_max_size = max_size + self._max_size = math.floor(max_size) self._entity_to_key = {} # type: Dict[EntityType, int] # map from stream id to the a set of entities which changed at that stream id. @@ -58,12 +60,31 @@ class StreamChangeCache: # self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = caches.register_cache("cache", self.name, self._cache) + self.metrics = caches.register_cache( + "cache", self.name, self._cache, resize_callback=self.set_cache_factor + ) if prefilled_cache: for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. 
+ """ + new_size = math.floor(self._original_max_size * factor) + if new_size != self._max_size: + self.max_size = new_size + self._evict() + return True + return False + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ @@ -171,6 +192,7 @@ class StreamChangeCache: e1 = self._cache[stream_pos] = set() e1.add(entity) self._entity_to_key[entity] = stream_pos + self._evict() # if the cache is too big, remove entries while len(self._cache) > self._max_size: @@ -179,6 +201,13 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] + def _evict(self): + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + self._entity_to_key.pop(entity, None) + def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 99646c7cf0..6437aa907e 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -38,7 +38,7 @@ class TTLCache(object): self._timer = timer - self._metrics = register_cache("ttl", cache_name, self) + self._metrics = register_cache("ttl", cache_name, self, resizable=False) def set(self, key, value, ttl): """Add/update an entry in the cache diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..2920279125 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. 
+ """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. 
+ """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. 
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 6857933540..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. 
""" @@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and diff --git a/tests/utils.py b/tests/utils.py index f9be62b499..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: -- cgit 1.5.1 From 56b66db78a3a6a22f65b219f4dc12899111f742b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 May 2020 13:24:01 -0400 Subject: Strictly enforce canonicaljson requirements in a new room version (#7381) --- changelog.d/7381.bugfix | 1 + synapse/api/room_versions.py | 24 ++++++++++++- synapse/events/utils.py | 35 +++++++++++++++++- synapse/events/validator.py | 7 ++++ synapse/federation/federation_base.py | 6 +++- synapse/util/frozenutils.py | 2 +- tests/handlers/test_federation.py | 67 ++++++++++++++++++++++++++++++++++- 7 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 changelog.d/7381.bugfix (limited to 'synapse/util') diff --git a/changelog.d/7381.bugfix b/changelog.d/7381.bugfix new file mode 100644 index 0000000000..e5f93571dc --- /dev/null +++ b/changelog.d/7381.bugfix @@ -0,0 +1 @@ +Add an experimental room version which strictly adheres to the canonical JSON specification. diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index af3612ed61..0901afb900 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -59,7 +59,11 @@ class RoomVersion(object): # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool) - + # Strictly enforce canonicaljson, do not allow: + # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] + # * Floats + # * NaN, Infinity, -Infinity + strict_canonicaljson = attr.ib(type=bool) # bool: MSC2209: Check 'notifications' key while verifying # m.room.power_levels auth rules. 
limit_notifications_power_levels = attr.ib(type=bool) @@ -73,6 +77,7 @@ class RoomVersions(object): StateResolutionVersions.V1, enforce_key_validity=False, special_case_aliases_auth=True, + strict_canonicaljson=False, limit_notifications_power_levels=False, ) V2 = RoomVersion( @@ -82,6 +87,7 @@ class RoomVersions(object): StateResolutionVersions.V2, enforce_key_validity=False, special_case_aliases_auth=True, + strict_canonicaljson=False, limit_notifications_power_levels=False, ) V3 = RoomVersion( @@ -91,6 +97,7 @@ class RoomVersions(object): StateResolutionVersions.V2, enforce_key_validity=False, special_case_aliases_auth=True, + strict_canonicaljson=False, limit_notifications_power_levels=False, ) V4 = RoomVersion( @@ -100,6 +107,7 @@ class RoomVersions(object): StateResolutionVersions.V2, enforce_key_validity=False, special_case_aliases_auth=True, + strict_canonicaljson=False, limit_notifications_power_levels=False, ) V5 = RoomVersion( @@ -109,6 +117,7 @@ class RoomVersions(object): StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=True, + strict_canonicaljson=False, limit_notifications_power_levels=False, ) MSC2432_DEV = RoomVersion( @@ -118,6 +127,17 @@ class RoomVersions(object): StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, + strict_canonicaljson=False, + limit_notifications_power_levels=False, + ) + STRICT_CANONICALJSON = RoomVersion( + "org.matrix.strict_canonicaljson", + RoomDisposition.UNSTABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=True, + strict_canonicaljson=True, limit_notifications_power_levels=False, ) MSC2209_DEV = RoomVersion( @@ -127,6 +147,7 @@ class RoomVersions(object): StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=True, + strict_canonicaljson=False, limit_notifications_power_levels=True, ) @@ -140,6 +161,7 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V4, RoomVersions.V5, RoomVersions.MSC2432_DEV, + RoomVersions.STRICT_CANONICALJSON, RoomVersions.MSC2209_DEV, ) } # type: Dict[str, RoomVersion] diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b75b097e5e..dd340be9a7 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,7 +14,7 @@ # limitations under the License. import collections import re -from typing import Mapping, Union +from typing import Any, Mapping, Union from six import string_types @@ -23,6 +23,7 @@ from frozendict import frozendict from twisted.internet import defer from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results @@ -449,3 +450,35 @@ def copy_power_levels_contents( raise TypeError("Invalid power_levels value for %s: %r" % (k, v)) return power_levels + + +def validate_canonicaljson(value: Any): + """ + Ensure that the JSON object is valid according to the rules of canonical JSON. + + See the appendix section 3.1: Canonical JSON. + + This rejects JSON that has: + * An integer outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] + * Floats + * NaN, Infinity, -Infinity + """ + if isinstance(value, int): + if value <= -(2 ** 53) or 2 ** 53 <= value: + raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) + + elif isinstance(value, float): + # Note that Infinity, -Infinity, and NaN are also considered floats. 
+ raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON) + + elif isinstance(value, (dict, frozendict)): + for v in value.values(): + validate_canonicaljson(v) + + elif isinstance(value, (list, tuple)): + for i in value: + validate_canonicaljson(i) + + elif not isinstance(value, (bool, str)) and value is not None: + # Other potential JSON values (bool, None, str) are safe. + raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 9b90c9ce04..b001c64bb4 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -18,6 +18,7 @@ from six import integer_types, string_types from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions +from synapse.events.utils import validate_canonicaljson from synapse.types import EventID, RoomID, UserID @@ -55,6 +56,12 @@ class EventValidator(object): if not isinstance(getattr(event, s), string_types): raise SynapseError(400, "'%s' not a string type" % (s,)) + # Depending on the room version, ensure the data is spec compliant JSON. + if event.room_version.strict_canonicaljson: + # Note that only the client controlled portion of the event is + # checked, since we trust the portions of the event we created. + validate_canonicaljson(event.content) + if event.type == EventTypes.Aliases: if "aliases" in event.content: for alias in event.content["aliases"]: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 4b115aac04..c0012c6872 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -29,7 +29,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict -from synapse.events.utils import prune_event +from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( PreserveLoggingContext, @@ -302,6 +302,10 @@ def event_from_pdu_json( elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) + # Validate that the JSON conforms to the specification. + if room_version.strict_canonicaljson: + validate_canonicaljson(pdu_json) + event = make_event_from_dict(pdu_json, room_version) event.internal_metadata.outlier = outlier diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index f2ccd5e7c6..9815bb8667 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -65,5 +65,5 @@ def _handle_frozendict(obj): ) -# A JSONEncoder which is capable of encoding frozendics without barfing +# A JSONEncoder which is capable of encoding frozendicts without barfing frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 132e35651d..dfef58e704 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +from unittest import TestCase from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin @@ -207,3 +210,65 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) return join_event + + +class EventFromPduTestCase(TestCase): + def test_valid_json(self): + """Valid JSON should be turned into an event.""" + ev = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"bool": True, "null": None, "int": 1, "str": "foobar"}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.STRICT_CANONICALJSON, + ) + + self.assertIsInstance(ev, EventBase) + + def test_invalid_numbers(self): + """Invalid values for an integer should be rejected, all floats should be rejected.""" + for value in [ + -(2 ** 53), + 2 ** 53, + 1.0, + float("inf"), + float("-inf"), + float("nan"), + ]: + with self.assertRaises(SynapseError): + event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"foo": value}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.STRICT_CANONICALJSON, + ) + + def test_invalid_nested(self): + """List and dictionaries are recursively searched.""" + with self.assertRaises(SynapseError): + event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"foo": [{"bar": 2 ** 56}]}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.STRICT_CANONICALJSON, + ) -- cgit 1.5.1 From 08fa96f03037178620f5f0dd609fac52fbf7f2d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:07:24 +0100 Subject: Remove `exception_to_unicode` this is a no-op on python 3. --- synapse/storage/database.py | 15 +++------------ synapse/util/stringutils.py | 36 ------------------------------------ 2 files changed, 3 insertions(+), 48 deletions(-) (limited to 'synapse/util') diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c3d0863429..9947dbce77 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor from synapse.types import Collection -from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -424,20 +423,14 @@ class Database(object): # This can happen if the database disappears mid # transaction. 
logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, + "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) + logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: @@ -449,9 +442,7 @@ class Database(object): conn.rollback() except self.engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, e1, ) continue raise diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 6899bcb788..2cfa5cf721 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -85,42 +85,6 @@ def to_ascii(s): return s -def exception_to_unicode(e): - """Helper function to extract the text of an exception as a unicode string - - Args: - e (Exception): exception to be stringified - - Returns: - unicode - """ - # urgh, this is a mess. The basic problem here is that psycopg2 constructs its - # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will - # then produce the raw byte sequence. Under Python 2, this will then cause another - # error if it gets mixed with a `unicode` object, as per - # https://github.com/matrix-org/synapse/issues/4252 - - # First of all, if we're under python3, everything is fine because it will sort this - # nonsense out for us. - if not PY2: - return str(e) - - # otherwise let's have a stab at decoding the exception message. We'll circumvent - # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii') - # and instead look at what is in the args member. - - if len(e.args) == 0: - return "" - elif len(e.args) > 1: - return six.text_type(repr(e.args)) - - msg = e.args[0] - if isinstance(msg, bytes): - return msg.decode("utf-8", errors="replace") - else: - return msg - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.5.1 From 65902e08c3f4449de9baa4e6466f126585f688b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:12:03 +0100 Subject: remove to_ascii this is a no-op on python 3. 
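As with `exception_to_unicode` in the previous commit, this helper only existed to paper over the bytes/unicode split in Python 2. A minimal illustration of why both are now redundant (an illustrative sketch, not part of the patch):

    # Under Python 3 every str is already unicode, so to_ascii()'s py3 branch
    # simply returned its argument unchanged, and str(exc) already yields text
    # (which is why exception_to_unicode, removed just above, was a no-op too).
    s = "@user:example.com"
    assert s.encode("ascii").decode("ascii") == s
    assert str(ValueError("café")) == "café"
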
--- synapse/storage/data_stores/main/roommember.py | 25 ++++++++++--------------- synapse/storage/data_stores/main/state.py | 5 +---- synapse/util/stringutils.py | 20 +------------------- 3 files changed, 12 insertions(+), 38 deletions(-) (limited to 'synapse/util') diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 48810a3e91..1e9c850152 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.metrics import Measure -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] + return [r[0] for r in txn] @cached(max_entries=100000) def get_room_summary(self, room_id): @@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + summary = res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] + summary = res[membership] # we will always have a summary for this membership type at this # point given the summary currently contains the counts. 
members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) + members.append((user_id, event_id)) return res @@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry: if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), + users_in_room[ev_entry.event.state_key] = ProfileInfo( + display_name=ev_entry.event.content.get("displayname", None), + avatar_url=ev_entry.event.content.get("avatar_url", None), ) else: missing_member_event_ids.append(event_id) @@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), + users_in_room[event.state_key] = ProfileInfo( + display_name=event.content.get("displayname", None), + avatar_url=event.content.get("avatar_url", None), ) return users_in_room diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 21052fcc7a..347cc50778 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -29,7 +29,6 @@ from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): (room_id,), ) - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return self.db.runInteraction( "get_current_state_ids", _get_current_state_ids_txn diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2cfa5cf721..81a44184ca 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,8 +19,7 @@ import re import string from collections import Iterable -import six -from six import PY2, PY3 +from six import PY3 from six.moves import range from synapse.api.errors import Codes, SynapseError @@ -68,23 +67,6 @@ def is_ascii(s): return True -def to_ascii(s): - """Converts a string to ascii if it is ascii, otherwise leave it alone. - - If given None then will return None. 
- """ - if PY3: - return s - - if s is None: - return None - - try: - return s.encode("ascii") - except UnicodeEncodeError: - return s - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.5.1 From d4676910c91dd492ca5cc7c207969fa7bfe1bbee Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:17:06 +0100 Subject: remove miscellaneous PY2 code --- synapse/http/matrixfederationclient.py | 8 ++------ synapse/logging/utils.py | 10 ++-------- synapse/push/httppusher.py | 11 +++-------- synapse/rest/media/v1/_base.py | 27 +++++++++------------------ synapse/util/caches/__init__.py | 7 +------ synapse/util/stringutils.py | 28 +++++++--------------------- 6 files changed, 24 insertions(+), 67 deletions(-) (limited to 'synapse/util') diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 225a47e3c3..44077f5349 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,7 +19,7 @@ import random import sys from io import BytesIO -from six import PY3, raise_from, string_types +from six import raise_from, string_types from six.moves import urllib import attr @@ -70,11 +70,7 @@ incoming_responses_counter = Counter( MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 - -if PY3: - MAXINT = sys.maxsize -else: - MAXINT = sys.maxint +MAXINT = sys.maxsize _next_id = 1 diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 0c2527bd86..99049bb5d8 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -20,8 +20,6 @@ import time from functools import wraps from inspect import getcallargs -from six import PY3 - _TIME_FUNC_ID = 0 @@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args): logger = logging.getLogger(name) if logger.isEnabledFor(logging.DEBUG): - if PY3: - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - else: - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename + lineno = f.__code__.co_firstlineno + pathname = f.__code__.co_filename record = logging.LogRecord( name=name, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 5bb17d1228..eaaa7afc91 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -import six - from prometheus_client import Counter from twisted.internet import defer @@ -28,9 +26,6 @@ from synapse.push import PusherConfigException from . 
import push_rule_evaluator, push_tools -if six.PY3: - long = int - logger = logging.getLogger(__name__) http_push_processed_counter = Counter( @@ -318,7 +313,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], @@ -347,7 +342,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, "tweaks": tweaks, } @@ -409,7 +404,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 503f2bed98..3689777266 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,6 @@ import logging import os -from six import PY3 from six.moves import urllib from twisted.internet import defer @@ -324,23 +323,15 @@ def get_filename_from_headers(headers): upload_name_utf8 = upload_name_utf8[7:] # We have a filename*= section. This MUST be ASCII, and any UTF-8 # bytes are %-quoted. - if PY3: - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - else: - # On Python 2, we first unquote the %-encoded parts and then - # decode it strictly using UTF-8. - try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") - except UnicodeDecodeError: - pass + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass # If there isn't check for an ascii name. if not upload_name: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 4b8a0c7a8f..dd356bf156 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -15,11 +15,9 @@ # limitations under the License. 
 
 import logging
+from sys import intern
 from typing import Callable, Dict, Optional
 
-import six
-from six.moves import intern
-
 import attr
 from prometheus_client.core import Gauge
 
@@ -154,9 +152,6 @@ def intern_string(string):
         return None
 
     try:
-        if six.PY2:
-            string = string.encode("ascii")
-
         return intern(string)
     except UnicodeEncodeError:
         return string
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 81a44184ca..08c86e92b8 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -19,9 +19,6 @@ import re
 import string
 from collections import Iterable
 
-from six import PY3
-from six.moves import range
-
 from synapse.api.errors import Codes, SynapseError
 
 _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -46,24 +43,13 @@ def random_string_with_symbols(length):
 
 
 def is_ascii(s):
-
-    if PY3:
-        if isinstance(s, bytes):
-            try:
-                s.decode("ascii").encode("ascii")
-            except UnicodeDecodeError:
-                return False
-            except UnicodeEncodeError:
-                return False
-        return True
-
-    try:
-        s.encode("ascii")
-    except UnicodeEncodeError:
-        return False
-    except UnicodeDecodeError:
-        return False
-    else:
+    if isinstance(s, bytes):
+        try:
+            s.decode("ascii").encode("ascii")
+        except UnicodeDecodeError:
+            return False
+        except UnicodeEncodeError:
+            return False
         return True
 
 
-- cgit 1.5.1


From a0f99f81b35b0e2ff600d7c72a0d71f15bf94f4c Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Fri, 22 May 2020 10:17:36 +0100
Subject: Fix stacktrace mangling in `patch_inline_callbacks` (#7554)

`Failure()` is more cunning than `Failure(e)`.
---
 changelog.d/7554.misc | 1 +
 synapse/util/patch_inline_callbacks.py | 9 +++++++--
 2 files changed, 8 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/7554.misc

(limited to 'synapse/util')

diff --git a/changelog.d/7554.misc b/changelog.d/7554.misc
new file mode 100644
index 0000000000..7c35c46aa6
--- /dev/null
+++ b/changelog.d/7554.misc
@@ -0,0 +1 @@
+Fix some test code to not mangle stacktraces, to make it easier to debug errors.
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index fdff195771..2605f3c65b 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -186,10 +186,15 @@ def _check_yield_points(f: Callable, changes: List[str]):
             )
             raise Exception(err)
 
+        # the wrapped function yielded a Deferred: yield it back up to the parent
+        # inlineCallbacks().
         try:
             result = yield d
-        except Exception as e:
-            result = Failure(e)
+        except Exception:
+            # this will fish an earlier Failure out of the stack where possible, and
+            # thus is preferable to passing in an exception to the Failure
+            # constructor, since it results in less stack-mangling.
+            result = Failure()
 
         if current_context() != expected_context:
-- cgit 1.5.1


From eefc6b3a0d08cd2a64be7c78c0a4a651cc965be9 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Wed, 27 May 2020 12:04:37 +0100
Subject: Don't apply cache factor to event cache. (#7578)

This is already correctly done when we instantiate the cache, but wasn't
when it got reloaded (which always happens at least once on startup).
---
 changelog.d/7578.bugfix | 1 +
 synapse/util/caches/lrucache.py | 4 ++++
 tests/config/test_cache.py | 16 ++++++++++++++++
 3 files changed, 21 insertions(+)
 create mode 100644 changelog.d/7578.bugfix

(limited to 'synapse/util')

diff --git a/changelog.d/7578.bugfix b/changelog.d/7578.bugfix
new file mode 100644
index 0000000000..cd29307361
--- /dev/null
+++ b/changelog.d/7578.bugfix
@@ -0,0 +1 @@
+Fix cache config to not apply cache factor to event cache. Regression in v1.14.0rc1.
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 29fabac3cd..df4ea5901d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -81,6 +81,7 @@ class LruCache(object):
         """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
+        self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
         # Save the original max size, and apply the default size factor.
         self._original_max_size = max_size
@@ -294,6 +295,9 @@ class LruCache(object):
         Returns:
             bool: Whether the cache changed size or not.
         """
+        if not self.apply_cache_factor_from_config:
+            return False
+
         new_size = int(self._original_max_size * factor)
         if new_size != self.max_size:
             self.max_size = new_size
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 2920279125..b45e0cc536 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -125,3 +125,19 @@ class CacheConfigTests(TestCase):
         cache = LruCache(100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 150)
+
+    def test_apply_cache_factor_from_config(self):
+        """Caches can disable applying cache factor updates, mainly used by
+        event cache size.
+        """
+
+        config = {"caches": {"event_cache_size": "10k"}}
+        t = TestConfig()
+        t.read_config(config, config_dir_path="", data_dir_path="")
+
+        cache = LruCache(
+            max_size=t.caches.event_cache_size, apply_cache_factor_from_config=False,
+        )
+        add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor)
+
+        self.assertEqual(cache.max_size, 10240)
-- cgit 1.5.1
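
A rough sketch of why the `apply_cache_factor_from_config` guard in the last patch matters follows. This is not the real `synapse.util.caches.lrucache.LruCache`: the class name `SketchCache` and the harness around it are invented for illustration, and only the `set_cache_factor` logic mirrors the hunk above.

    # Stripped-down stand-in for LruCache; names and harness are hypothetical,
    # only the set_cache_factor() guard reflects the patch above.
    class SketchCache:
        def __init__(self, max_size, apply_cache_factor_from_config=True):
            self.apply_cache_factor_from_config = apply_cache_factor_from_config
            # Remember the configured size so later factor changes scale from it.
            self._original_max_size = max_size
            self.max_size = max_size

        def set_cache_factor(self, factor):
            """Called whenever the global cache factor is (re)loaded."""
            if not self.apply_cache_factor_from_config:
                # e.g. the event cache: its size comes straight from the
                # event_cache_size setting, so ignore the global factor.
                return False
            new_size = int(self._original_max_size * factor)
            if new_size != self.max_size:
                self.max_size = new_size
                return True
            return False

    # With the guard, a reload of cache_factor=1.5 leaves a 10240-entry event
    # cache alone; without it, the cache would silently grow to 15360.
    event_cache = SketchCache(10240, apply_cache_factor_from_config=False)
    assert event_cache.set_cache_factor(1.5) is False
    assert event_cache.max_size == 10240

Without that early return, every reload of the global cache factor (which, per the commit message, happens at least once on startup) would rescale the event cache away from its explicitly configured size.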