diff options
84 files changed, 1970 insertions, 730 deletions
diff --git a/CHANGES.md b/CHANGES.md index 75dc5fa893..38a0814bbf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,27 @@ +Synapse 1.21.2 (2020-10-15) +=========================== + +Debian packages and Docker images have been rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below. + +Security advisory +----------------- + +* HTML pages served via Synapse were vulnerable to cross-site scripting (XSS) + attacks. All server administrators are encouraged to upgrade. + ([\#8444](https://github.com/matrix-org/synapse/pull/8444)) + ([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891)) + + This fix was originally included in v1.21.0 but was missing a security advisory. + + This was reported by [Denis Kasak](https://github.com/dkasak). + +Bugfixes +-------- + +- Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530)) +- An updated version of the authlib dependency is included in the Docker and Debian images to fix an issue using OpenID Connect. See [\#8534](https://github.com/matrix-org/synapse/issues/8534) for details. + + Synapse 1.21.1 (2020-10-13) =========================== diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 524f82433d..c17e3b2399 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,6 +63,10 @@ run-time: ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder ``` +You can also provide the `-d` option, which will lint the files that have been +changed since the last git commit. This will often be significantly faster than +linting the whole codebase. + Before pushing new changes, ensure they don't produce linting errors. Commit any files that were corrected. 
diff --git a/changelog.d/8437.feature b/changelog.d/8437.feature new file mode 100644 index 0000000000..4abcccb326 --- /dev/null +++ b/changelog.d/8437.feature @@ -0,0 +1 @@ +Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices. diff --git a/changelog.d/8472.misc b/changelog.d/8472.misc new file mode 100644 index 0000000000..880f3f5e14 --- /dev/null +++ b/changelog.d/8472.misc @@ -0,0 +1 @@ +Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit. \ No newline at end of file diff --git a/changelog.d/8488.misc b/changelog.d/8488.misc new file mode 100644 index 0000000000..237cb3b311 --- /dev/null +++ b/changelog.d/8488.misc @@ -0,0 +1 @@ +Allow events to be sent to clients sooner when using sharded event persisters. diff --git a/changelog.d/8503.misc b/changelog.d/8503.misc new file mode 100644 index 0000000000..edb1be8aa8 --- /dev/null +++ b/changelog.d/8503.misc @@ -0,0 +1 @@ +Add user agent to user_daily_visits table. diff --git a/changelog.d/8515.misc b/changelog.d/8515.misc new file mode 100644 index 0000000000..1f8aa292d8 --- /dev/null +++ b/changelog.d/8515.misc @@ -0,0 +1 @@ +Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable. diff --git a/changelog.d/8517.bugfix b/changelog.d/8517.bugfix new file mode 100644 index 0000000000..1ab623c59f --- /dev/null +++ b/changelog.d/8517.bugfix @@ -0,0 +1 @@ +Fix error code for `/profile/{userId}/displayname` to be `M_BAD_JSON`. diff --git a/changelog.d/8526.doc b/changelog.d/8526.doc new file mode 100644 index 0000000000..cbf48680c1 --- /dev/null +++ b/changelog.d/8526.doc @@ -0,0 +1 @@ +Added note about docker in manhole.md regarding which ip address to bind to. Contributed by @Maquis196. 
diff --git a/changelog.d/8527.bugfix b/changelog.d/8527.bugfix new file mode 100644 index 0000000000..727e0ba299 --- /dev/null +++ b/changelog.d/8527.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.7.0 that could cause Synapse to insert values from non-state `m.room.retention` events into the `room_retention` database table. diff --git a/changelog.d/8529.doc b/changelog.d/8529.doc new file mode 100644 index 0000000000..6e710e6527 --- /dev/null +++ b/changelog.d/8529.doc @@ -0,0 +1 @@ +Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration. diff --git a/changelog.d/8536.bugfix b/changelog.d/8536.bugfix new file mode 100644 index 0000000000..8d238cc008 --- /dev/null +++ b/changelog.d/8536.bugfix @@ -0,0 +1 @@ +Fix not sending events over federation when using sharded event writers. diff --git a/changelog.d/8537.misc b/changelog.d/8537.misc new file mode 100644 index 0000000000..26309b5b93 --- /dev/null +++ b/changelog.d/8537.misc @@ -0,0 +1 @@ +Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`. diff --git a/changelog.d/8542.misc b/changelog.d/8542.misc new file mode 100644 index 0000000000..63149fd9b9 --- /dev/null +++ b/changelog.d/8542.misc @@ -0,0 +1 @@ +Improve database performance by executing more queries without starting transactions. diff --git a/changelog.d/8547.misc b/changelog.d/8547.misc new file mode 100644 index 0000000000..fafb1c8347 --- /dev/null +++ b/changelog.d/8547.misc @@ -0,0 +1 @@ +Enable mypy type checking for `synapse.util.caches`. diff --git a/changelog.d/8548.misc b/changelog.d/8548.misc new file mode 100644 index 0000000000..fba10bd731 --- /dev/null +++ b/changelog.d/8548.misc @@ -0,0 +1 @@ +Rename `Cache` to `DeferredCache`, to better reflect its purpose. 
diff --git a/debian/changelog b/debian/changelog index eeafd4f50a..8d873a4845 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,10 @@ +matrix-synapse-py3 (1.21.2) stable; urgency=medium + + [ Synapse Packaging team ] + * New synapse release 1.21.2. + + -- Synapse Packaging team <packages@matrix.org> Thu, 15 Oct 2020 09:23:27 -0400 + matrix-synapse-py3 (1.21.1) stable; urgency=medium [ Synapse Packaging team ] diff --git a/docs/manhole.md b/docs/manhole.md index 75b6ae40e0..37d1d7823c 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -5,22 +5,54 @@ The "manhole" allows server administrators to access a Python shell on a running Synapse installation. This is a very powerful mechanism for administration and debugging. +**_Security Warning_** + +Note that this will give administrative access to synapse to **all users** with +shell access to the server. It should therefore **not** be enabled in +environments where untrusted users have shell access. + +*** + To enable it, first uncomment the `manhole` listener configuration in -`homeserver.yaml`: +`homeserver.yaml`. The configuration is slightly different if you're using docker. + +#### Docker config + +If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown: ```yaml listeners: - port: 9000 - bind_addresses: ['::1', '127.0.0.1'] + bind_addresses: ['0.0.0.0'] type: manhole ``` -(`bind_addresses` in the above is important: it ensures that access to the -manhole is only possible for local users). +When using `docker run` to start the server, you will then need to change the command to the following to include the +`manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it +ensures that access to the `manhole` is only possible for local users. -Note that this will give administrative access to synapse to **all users** with -shell access to the server. It should therefore **not** be enabled in -environments where untrusted users have shell access. 
+```bash +docker run -d --name synapse \ + --mount type=volume,src=synapse-data,dst=/data \ + -p 8008:8008 \ + -p 127.0.0.1:9000:9000 \ + matrixdotorg/synapse:latest +``` + +#### Native config + +If you are not using docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown. +The `bind_addresses` in the example below is important: it ensures that access to the +`manhole` is only possible for local users. + +```yaml +listeners: + - port: 9000 + bind_addresses: ['::1', '127.0.0.1'] + type: manhole +``` + +#### Accessing synapse manhole Then restart synapse, and point an ssh client at port 9000 on localhost, using the username `matrix`: diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md index 1dd60bdad9..75d2028e17 100644 --- a/docs/message_retention_policies.md +++ b/docs/message_retention_policies.md @@ -136,24 +136,34 @@ the server's database. ### Lifetime limits -**Note: this feature is mainly useful within a closed federation or on -servers that don't federate, because there currently is no way to -enforce these limits in an open federation.** - -Server admins can restrict the values their local users are allowed to -use for both `min_lifetime` and `max_lifetime`. These limits can be -defined as such in the `retention` section of the configuration file: +Server admins can set limits on the values of `max_lifetime` to use when +purging old events in a room. These limits can be defined as such in the +`retention` section of the configuration file: ```yaml allowed_lifetime_min: 1d allowed_lifetime_max: 1y ``` -Here, `allowed_lifetime_min` is the lowest value a local user can set -for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max` -is the highest value. Both parameters are optional (e.g. setting -`allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a -minimum and no maximum). +The limits are considered when running purge jobs. 
If necessary, the +effective value of `max_lifetime` will be brought between +`allowed_lifetime_min` and `allowed_lifetime_max` (inclusive). +This means that, if the value of `max_lifetime` defined in the room's state +is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min` +will be used instead. Likewise, if the value of `max_lifetime` is higher +than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be +used instead. + +In the example above, we ensure Synapse never deletes events that are less +than one day old, and that it always deletes events that are over a year +old. + +If a default policy is set, and its `max_lifetime` value is lower than +`allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same +process applies. + +Both parameters are optional; if one is omitted Synapse won't use it to +adjust the effective value of `max_lifetime`. Like other settings in this section, these parameters can be expressed either as a duration or as a number of milliseconds. 
diff --git a/mypy.ini b/mypy.ini index f08fe992a4..b5db54ee3b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -15,6 +15,7 @@ files = synapse/events/builder.py, synapse/events/spamcheck.py, synapse/federation, + synapse/handlers/appservice.py, synapse/handlers/account_data.py, synapse/handlers/auth.py, synapse/handlers/cas_handler.py, @@ -64,9 +65,7 @@ files = synapse/streams, synapse/types.py, synapse/util/async_helpers.py, - synapse/util/caches/descriptors.py, - synapse/util/caches/response_cache.py, - synapse/util/caches/stream_change_cache.py, + synapse/util/caches, synapse/util/metrics.py, tests/replication, tests/test_utils, diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 0647993658..f2b65a2105 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/bash # # Runs linting scripts over the local Synapse checkout # isort - sorts import statements @@ -7,15 +7,90 @@ set -e -if [ $# -ge 1 ] -then - files=$* +usage() { + echo + echo "Usage: $0 [-h] [-d] [paths...]" + echo + echo "-d" + echo " Lint files that have changed since the last git commit." + echo + echo " If paths are provided and this option is set, both provided paths and those" + echo " that have changed since the last commit will be linted." + echo + echo " If no paths are provided and this option is not set, all files will be linted." + echo + echo " Note that paths with a file extension that is not '.py' will be excluded." + echo "-h" + echo " Display this help text." +} + +USING_DIFF=0 +files=() + +while getopts ":dh" opt; do + case $opt in + d) + USING_DIFF=1 + ;; + h) + usage + exit + ;; + \?) 
+ echo "ERROR: Invalid option: -$OPTARG" >&2 + usage + exit + ;; + esac +done + +# Strip any options from the command line arguments now that +# we've finished processing them +shift "$((OPTIND-1))" + +if [ $USING_DIFF -eq 1 ]; then + # Check both staged and non-staged changes + for path in $(git diff HEAD --name-only); do + filename=$(basename "$path") + file_extension="${filename##*.}" + + # If an extension is present, and it's something other than 'py', + # then ignore this file + if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then + continue + fi + + # Append this path to our list of files to lint + files+=("$path") + done +fi + +# Append any remaining arguments as files to lint +files+=("$@") + +if [[ $USING_DIFF -eq 1 ]]; then + # If we were asked to lint changed files, and no paths were found as a result... + if [ ${#files[@]} -eq 0 ]; then + # Then print and exit + echo "No files found to lint." + exit 0 + fi else - files="synapse tests scripts-dev scripts contrib synctl" + # If we were not asked to lint changed files, and no paths were found as a result, + # then lint everything! + if [[ -z ${files+x} ]]; then + # Lint all source code files and directories + files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py") + fi fi -echo "Linting these locations: $files" -isort $files -python3 -m black $files +echo "Linting these paths: ${files[*]}" +echo + +# Print out the commands being run +set -x + +isort "${files[@]}" +python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh -flake8 $files +flake8 "${files[@]}" diff --git a/setup.py b/setup.py index 926b1bc86f..08843fe2a3 100755 --- a/setup.py +++ b/setup.py @@ -15,12 +15,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import glob import os -from setuptools import setup, find_packages, Command -import sys +from setuptools import Command, find_packages, setup here = os.path.abspath(os.path.dirname(__file__)) diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi index 073b806d3c..fa307483fe 100644 --- a/stubs/sortedcontainers/__init__.pyi +++ b/stubs/sortedcontainers/__init__.pyi @@ -1,13 +1,12 @@ -from .sorteddict import ( - SortedDict, - SortedKeysView, - SortedItemsView, - SortedValuesView, -) +from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView +from .sortedlist import SortedKeyList, SortedList, SortedListWithKey __all__ = [ "SortedDict", "SortedKeysView", "SortedItemsView", "SortedValuesView", + "SortedKeyList", + "SortedList", + "SortedListWithKey", ] diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi new file mode 100644 index 0000000000..8f6086b3ff --- /dev/null +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -0,0 +1,177 @@ +# stub for SortedList. This is an exact copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Generic, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +_T = TypeVar("_T") +_SL = TypeVar("_SL", bound=SortedList) +_SKL = TypeVar("_SKL", bound=SortedKeyList) +_Key = Callable[[_T], Any] +_Repr = Callable[[], str] + +def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... + +class SortedList(MutableSequence[_T]): + + DEFAULT_LOAD_FACTOR: int = ... + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., + ): ... 
+ # NB: currently mypy does not honour return type, see mypy #3307 + @overload + def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ... + @overload + def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ... + @overload + def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ... + @overload + def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ... + @property + def key(self) -> Optional[Callable[[_T], Any]]: ... + def _reset(self, load: int) -> None: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def _loc(self, pos: int, idx: int) -> int: ... + def _pos(self, idx: int) -> int: ... + def _build_index(self) -> None: ... + def __contains__(self, value: Any) -> bool: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + @overload + def _getitem(self, index: int) -> _T: ... + @overload + def _getitem(self, index: slice) -> List[_T]: ... + @overload + def __setitem__(self, index: int, value: _T) -> None: ... + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + def __iter__(self) -> Iterator[_T]: ... + def __reversed__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def reverse(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_T]: ... + def _islice( + self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, + ) -> Iterator[_T]: ... 
+ def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ) -> Iterator[_T]: ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def _bisect_right(self, value: _T) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SL) -> _SL: ... + def __copy__(self: _SL) -> _SL: ... + def append(self, value: _T) -> None: ... + def extend(self, values: Iterable[_T]) -> None: ... + def insert(self, index: int, value: _T) -> None: ... + def pop(self, index: int = ...) -> _T: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __mul__(self: _SL, num: int) -> _SL: ... + def __rmul__(self: _SL, num: int) -> _SL: ... + def __imul__(self: _SL, num: int) -> _SL: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __lt__(self, other: Sequence[_T]) -> bool: ... + def __gt__(self, other: Sequence[_T]) -> bool: ... + def __le__(self, other: Sequence[_T]) -> bool: ... + def __ge__(self, other: Sequence[_T]) -> bool: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +class SortedKeyList(SortedList[_T]): + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> None: ... + def __new__( + cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> SortedKeyList[_T]: ... + @property + def key(self) -> Callable[[_T], Any]: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... 
+ def _update(self, iterable: Iterable[_T]) -> None: ... + # NB: Must be T to be safely passed to self.func, yet base class imposes Any + def __contains__(self, value: _T) -> bool: ... # type: ignore + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ): ... + def irange_key( + self, + min_key: Optional[Any] = ..., + max_key: Optional[Any] = ..., + inclusive: Tuple[bool, bool] = ..., + reserve: bool = ..., + ): ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def bisect_key_left(self, key: Any) -> int: ... + def _bisect_key_left(self, key: Any) -> int: ... + def bisect_key_right(self, key: Any) -> int: ... + def _bisect_key_right(self, key: Any) -> int: ... + def bisect_key(self, key: Any) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SKL) -> _SKL: ... + def __copy__(self: _SKL) -> _SKL: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __mul__(self: _SKL, num: int) -> _SKL: ... + def __rmul__(self: _SKL, num: int) -> _SKL: ... + def __imul__(self: _SKL, num: int) -> _SKL: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... 
+ +SortedListWithKey = SortedKeyList diff --git a/synapse/__init__.py b/synapse/__init__.py index 722b53a67d..83b8e4897f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.21.1" +__version__ = "1.21.2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d53181deb1..1b511890aa 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -790,10 +790,6 @@ class FederationSenderHandler: send_queue.process_rows_for_federation(self.federation_sender, rows) await self.update_token(token) - # We also need to poke the federation sender when new events happen - elif stream_name == "events": - self.federation_sender.notify_new_events(token) - # ... and when new receipts happen elif stream_name == ReceiptsStream.NAME: await self._on_new_receipts(rows) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 13ec1f71a6..3862d9c08f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,14 +14,15 @@ # limitations under the License. 
import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes -from synapse.appservice.api import ApplicationServiceApi -from synapse.types import GroupID, get_domain_from_id +from synapse.events import EventBase +from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.appservice.api import ApplicationServiceApi from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -32,38 +33,6 @@ class ApplicationServiceState: UP = "up" -class AppServiceTransaction: - """Represents an application service transaction.""" - - def __init__(self, service, id, events): - self.service = service - self.id = id - self.events = events - - async def send(self, as_api: ApplicationServiceApi) -> bool: - """Sends this transaction using the provided AS API interface. - - Args: - as_api: The API to use to send. - Returns: - True if the transaction was sent. - """ - return await as_api.push_bulk( - service=self.service, events=self.events, txn_id=self.id - ) - - async def complete(self, store: "DataStore") -> None: - """Completes this transaction as successful. - - Marks this transaction ID on the application service and removes the - transaction contents from the database. - - Args: - store: The database store to operate on. - """ - await store.complete_appservice_txn(service=self.service, txn_id=self.id) - - class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. 
@@ -91,6 +60,7 @@ class ApplicationService: protocols=None, rate_limited=True, ip_range_whitelist=None, + supports_ephemeral=False, ): self.token = token self.url = ( @@ -102,6 +72,7 @@ class ApplicationService: self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist + self.supports_ephemeral = supports_ephemeral if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -161,19 +132,21 @@ class ApplicationService: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces - def _matches_regex(self, test_string, namespace_key): + def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: for regex_obj in self.namespaces[namespace_key]: if regex_obj["regex"].match(test_string): return regex_obj return None - def _is_exclusive(self, ns_key, test_string): + def _is_exclusive(self, ns_key: str, test_string: str) -> bool: regex_obj = self._matches_regex(test_string, ns_key) if regex_obj: return regex_obj["exclusive"] return False - async def _matches_user(self, event, store): + async def _matches_user( + self, event: Optional[EventBase], store: Optional["DataStore"] = None + ) -> bool: if not event: return False @@ -188,14 +161,23 @@ class ApplicationService: if not store: return False - does_match = await self._matches_user_in_member_list(event.room_id, store) + does_match = await self.matches_user_in_member_list(event.room_id, store) return does_match - @cached(num_args=1, cache_context=True) - async def _matches_user_in_member_list(self, room_id, store, cache_context): - member_list = await store.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) + @cached(num_args=1) + async def matches_user_in_member_list( + self, room_id: str, store: "DataStore" + ) -> bool: + """Check if this service is interested a room based upon it's membership + + Args: + room_id: The room to check. + store: The datastore to query. 
+ + Returns: + True if this service would like to know about this room. + """ + member_list = await store.get_users_in_room(room_id) # check joined member events for user_id in member_list: @@ -203,12 +185,14 @@ class ApplicationService: return True return False - def _matches_room_id(self, event): + def _matches_room_id(self, event: EventBase) -> bool: if hasattr(event, "room_id"): return self.is_interested_in_room(event.room_id) return False - async def _matches_aliases(self, event, store): + async def _matches_aliases( + self, event: EventBase, store: Optional["DataStore"] = None + ) -> bool: if not store or not event: return False @@ -218,12 +202,15 @@ class ApplicationService: return True return False - async def is_interested(self, event, store=None) -> bool: + async def is_interested( + self, event: EventBase, store: Optional["DataStore"] = None + ) -> bool: """Check if this service is interested in this event. Args: - event(Event): The event to check. - store(DataStore) + event: The event to check. + store: The datastore to query. + Returns: True if this service would like to know about this event. """ @@ -231,39 +218,66 @@ class ApplicationService: if self._matches_room_id(event): return True + # This will check the namespaces first before + # checking the store, so should be run before _matches_aliases + if await self._matches_user(event, store): + return True + + # This will check the store, so should be run last if await self._matches_aliases(event, store): return True - if await self._matches_user(event, store): + return False + + @cached(num_args=1) + async def is_interested_in_presence( + self, user_id: UserID, store: "DataStore" + ) -> bool: + """Check if this service is interested a user's presence + + Args: + user_id: The user to check. + store: The datastore to query. + + Returns: + True if this service would like to know about presence for this user. 
+ """ + # Find all the rooms the sender is in + if self.is_interested_in_user(user_id.to_string()): return True + room_ids = await store.get_rooms_for_user(user_id.to_string()) + # Then find out if the appservice is interested in any of those rooms + for room_id in room_ids: + if await self.matches_user_in_member_list(room_id, store): + return True return False - def is_interested_in_user(self, user_id): + def is_interested_in_user(self, user_id: str) -> bool: return ( - self._matches_regex(user_id, ApplicationService.NS_USERS) + bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) or user_id == self.sender ) - def is_interested_in_alias(self, alias): + def is_interested_in_alias(self, alias: str) -> bool: return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) - def is_interested_in_room(self, room_id): + def is_interested_in_room(self, room_id: str) -> bool: return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) - def is_exclusive_user(self, user_id): + def is_exclusive_user(self, user_id: str) -> bool: return ( self._is_exclusive(ApplicationService.NS_USERS, user_id) or user_id == self.sender ) - def is_interested_in_protocol(self, protocol): + def is_interested_in_protocol(self, protocol: str) -> bool: return protocol in self.protocols - def is_exclusive_alias(self, alias): + def is_exclusive_alias(self, alias: str) -> bool: return self._is_exclusive(ApplicationService.NS_ALIASES, alias) - def is_exclusive_room(self, room_id): + def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) def get_exclusive_user_regexes(self): @@ -276,14 +290,14 @@ class ApplicationService: if regex_obj["exclusive"] ] - def get_groups_for_user(self, user_id): + def get_groups_for_user(self, user_id: str) -> Iterable[str]: """Get the groups that this user is associated with by this AS Args: - user_id (str): The ID of the user. + user_id: The ID of the user. 
Returns: - iterable[str]: an iterable that yields group_id strings. + An iterable that yields group_id strings. """ return ( regex_obj["group_id"] @@ -291,7 +305,7 @@ class ApplicationService: if "group_id" in regex_obj and regex_obj["regex"].match(user_id) ) - def is_rate_limited(self): + def is_rate_limited(self) -> bool: return self.rate_limited def __str__(self): @@ -300,3 +314,45 @@ class ApplicationService: dict_copy["token"] = "<redacted>" dict_copy["hs_token"] = "<redacted>" return "ApplicationService: %s" % (dict_copy,) + + +class AppServiceTransaction: + """Represents an application service transaction.""" + + def __init__( + self, + service: ApplicationService, + id: int, + events: List[EventBase], + ephemeral: List[JsonDict], + ): + self.service = service + self.id = id + self.events = events + self.ephemeral = ephemeral + + async def send(self, as_api: "ApplicationServiceApi") -> bool: + """Sends this transaction using the provided AS API interface. + + Args: + as_api: The API to use to send. + Returns: + True if the transaction was sent. + """ + return await as_api.push_bulk( + service=self.service, + events=self.events, + ephemeral=self.ephemeral, + txn_id=self.id, + ) + + async def complete(self, store: "DataStore") -> None: + """Completes this transaction as successful. + + Marks this transaction ID on the application service and removes the + transaction contents from the database. + + Args: + store: The database store to operate on. + """ + await store.complete_appservice_txn(service=self.service, txn_id=self.id) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e8f0793795..e366a982b8 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,12 +14,13 @@ # limitations under the License. 
import logging import urllib -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from prometheus_client import Counter from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, ThirdPartyInstanceID @@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient): key = (service.id, protocol) return await self.protocol_meta_cache.wrap(key, _get) - async def push_bulk(self, service, events, txn_id=None): + async def push_bulk( + self, + service: "ApplicationService", + events: List[EventBase], + ephemeral: List[JsonDict], + txn_id: Optional[int] = None, + ): if service.url is None: return True @@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning( "push_bulk: Missing txn ID sending events to %s", service.url ) - txn_id = str(0) - txn_id = str(txn_id) + txn_id = 0 + + uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) + + # Never send ephemeral events to appservices that do not support it + if service.supports_ephemeral: + body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} + else: + body = {"events": events} - uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: await self.put_json( - uri=uri, - json_body={"events": events}, - args={"access_token": service.hs_token}, + uri=uri, json_body=body, args={"access_token": service.hs_token}, ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 8eb8c6f51c..ad3c408519 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -49,10 +49,13 @@ This is all tied together by the 
AppServiceScheduler which DIs the required components. """ import logging +from typing import List -from synapse.appservice import ApplicationServiceState +from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -82,8 +85,13 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service, event): - self.queuer.enqueue(service, event) + def submit_event_for_as(self, service: ApplicationService, event: EventBase): + self.queuer.enqueue_event(service, event) + + def submit_ephemeral_events_for_as( + self, service: ApplicationService, events: List[JsonDict] + ): + self.queuer.enqueue_ephemeral(service, events) class _ServiceQueuer: @@ -96,17 +104,15 @@ class _ServiceQueuer: def __init__(self, txn_ctrl, clock): self.queued_events = {} # dict of {service_id: [events]} + self.queued_ephemeral = {} # dict of {service_id: [events]} # the appservices which currently have a transaction in flight self.requests_in_flight = set() self.txn_ctrl = txn_ctrl self.clock = clock - def enqueue(self, service, event): - self.queued_events.setdefault(service.id, []).append(event) - + def _start_background_request(self, service): # start a sender for this appservice if we don't already have one - if service.id in self.requests_in_flight: return @@ -114,7 +120,15 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - async def _send_request(self, service): + def enqueue_event(self, service: ApplicationService, event: EventBase): + self.queued_events.setdefault(service.id, []).append(event) + self._start_background_request(service) + + def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): + 
self.queued_ephemeral.setdefault(service.id, []).extend(events) + self._start_background_request(service) + + async def _send_request(self, service: ApplicationService): # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -123,10 +137,11 @@ class _ServiceQueuer: try: while True: events = self.queued_events.pop(service.id, []) - if not events: + ephemeral = self.queued_ephemeral.pop(service.id, []) + if not events and not ephemeral: return try: - await self.txn_ctrl.send(service, events) + await self.txn_ctrl.send(service, events, ephemeral) except Exception: logger.exception("AS request failed") finally: @@ -158,9 +173,16 @@ class _TransactionController: # for UTs self.RECOVERER_CLASS = _Recoverer - async def send(self, service, events): + async def send( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict] = [], + ): try: - txn = await self.store.create_appservice_txn(service=service, events=events) + txn = await self.store.create_appservice_txn( + service=service, events=events, ephemeral=ephemeral + ) service_is_up = await self._is_service_up(service) if service_is_up: sent = await txn.send(self.as_api) @@ -204,7 +226,7 @@ class _TransactionController: recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) - async def _is_service_up(self, service): + async def _is_service_up(self, service: ApplicationService) -> bool: state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 8ed3e24258..746fc3cc02 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename): if as_info.get("ip_range_whitelist"): ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) + supports_ephemeral = 
as_info.get("de.sorunome.msc2409.push_ephemeral", False) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename): hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], + supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, diff --git a/synapse/events/builder.py b/synapse/events/builder.py index b6c47be646..df4f950fec 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -97,32 +97,37 @@ class EventBuilder: def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids: List[str]) -> EventBase: + async def build( + self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]] + ) -> EventBase: """Transform into a fully signed and hashed event Args: prev_event_ids: The event IDs to use as the prev events + auth_event_ids: The event IDs to use as the auth events. + Should normally be set to None, which will cause them to be calculated + based on the room state at the prev_events. Returns: The signed and hashed event. """ - - state_ids = await self._state.get_current_state_ids( - self.room_id, prev_event_ids - ) - auth_ids = self._auth.compute_auth_events(self, state_ids) + if auth_event_ids is None: + state_ids = await self._state.get_current_state_ids( + self.room_id, prev_event_ids + ) + auth_event_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: # The types of auth/prev events changes between event versions. 
auth_events = await self._store.add_event_hashes( - auth_ids + auth_event_ids ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] prev_events = await self._store.add_event_hashes( prev_event_ids ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: - auth_events = auth_ids + auth_events = auth_event_ids prev_events = prev_event_ids old_depth = await self._store.get_max_depth_of(prev_event_ids) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 9df35b54ba..5f9af8529b 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -83,6 +83,9 @@ class EventValidator: Args: event (FrozenEvent): The event to validate. """ + if not event.is_state(): + raise SynapseError(code=400, msg="must be a state event") + min_lifetime = event.content.get("min_lifetime") max_lifetime = event.content.get("max_lifetime") diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 8e46957d15..5f1bf492c1 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -188,7 +188,7 @@ class FederationRemoteSendQueue: for key in keys[:i]: del self.edus[key] - def notify_new_events(self, current_id): + def notify_new_events(self, max_token): """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index e33b29a42c..604cfd1935 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -40,7 +40,7 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import ReadReceipt +from synapse.types import ReadReceipt, RoomStreamToken from synapse.util.metrics import Measure, measure_func logger = logging.getLogger(__name__) @@ -154,10 +154,15 @@ class FederationSender: self._per_destination_queues[destination] = queue return queue - def notify_new_events(self, current_id: int) -> None: + def notify_new_events(self, max_token: RoomStreamToken) -> None: """This gets called when we have some new events we might want to send out to other servers. """ + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + current_id = max_token.stream + self._last_poked_id = max(current_id, self._last_poked_id) if self._is_processing: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9d4e87dad6..07240d3a14 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -14,6 +14,7 @@ # limitations under the License. 
import logging +from typing import Dict, List, Optional from prometheus_client import Counter @@ -21,12 +22,16 @@ from twisted.internet import defer import synapse from synapse.api.constants import EventTypes +from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( event_processing_loop_counter, event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import Collection, JsonDict, RoomStreamToken, UserID from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -43,19 +48,22 @@ class ApplicationServicesHandler: self.started_scheduler = False self.clock = hs.get_clock() self.notify_appservices = hs.config.notify_appservices + self.event_sources = hs.get_event_sources() self.current_max = 0 self.is_processing = False - async def notify_interested_services(self, current_id): + async def notify_interested_services(self, max_token: RoomStreamToken): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any prolonged length of time. - - Args: - current_id(int): The current maximum ID. """ + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. 
+ current_id = max_token.stream + services = self.store.get_app_services() if not services or not self.notify_appservices: return @@ -79,7 +87,7 @@ class ApplicationServicesHandler: if not events: break - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -158,6 +166,104 @@ class ApplicationServicesHandler: finally: self.is_processing = False + async def notify_interested_services_ephemeral( + self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [], + ): + """This is called by the notifier in the background + when an ephemeral event is handled by the homeserver. + + This will determine which appservices + are interested in the event, and submit them. + + Events will only be pushed to appservices + that have opted into ephemeral events + + Args: + stream_key: The stream the event came from. + new_token: The latest stream token + users: The user(s) involved with the event. 
+ """ + services = [ + service + for service in self.store.get_app_services() + if service.supports_ephemeral + ] + if not services or not self.notify_appservices: + return + logger.info("Checking interested services for %s" % (stream_key)) + with Measure(self.clock, "notify_interested_services_ephemeral"): + for service in services: + # Only handle typing if we have the latest token + if stream_key == "typing_key" and new_token is not None: + events = await self._handle_typing(service, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + # We don't persist the token for typing_key for performance reasons + elif stream_key == "receipt_key": + events = await self._handle_receipts(service) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + await self.store.set_type_stream_id_for_appservice( + service, "read_receipt", new_token + ) + elif stream_key == "presence_key": + events = await self._handle_presence(service, users) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + await self.store.set_type_stream_id_for_appservice( + service, "presence", new_token + ) + + async def _handle_typing(self, service: ApplicationService, new_token: int): + typing_source = self.event_sources.sources["typing"] + # Get the typing events from just before current + typing, _ = await typing_source.get_new_events_as( + service=service, + # For performance reasons, we don't persist the previous + # token in the DB and instead fetch the latest typing information + # for appservices. 
+ from_key=new_token - 1, + ) + return typing + + async def _handle_receipts(self, service: ApplicationService): + from_key = await self.store.get_type_stream_id_for_appservice( + service, "read_receipt" + ) + receipts_source = self.event_sources.sources["receipt"] + receipts, _ = await receipts_source.get_new_events_as( + service=service, from_key=from_key + ) + return receipts + + async def _handle_presence( + self, service: ApplicationService, users: Collection[UserID] + ): + events = [] # type: List[JsonDict] + presence_source = self.event_sources.sources["presence"] + from_key = await self.store.get_type_stream_id_for_appservice( + service, "presence" + ) + for user in users: + interested = await service.is_interested_in_presence(user, self.store) + if not interested: + continue + presence_events, _ = await presence_source.get_new_events( + user=user, service=service, from_key=from_key, + ) + time_now = self.clock.time_msec() + presence_events = [ + { + "type": "m.presence", + "sender": event.user_id, + "content": format_user_presence_state( + event, time_now, include_user_id=False + ), + } + for event in presence_events + ] + events = events + presence_events + async def query_user_exists(self, user_id): """Check if any application service knows this user_id exists. 
@@ -220,7 +326,7 @@ class ApplicationServicesHandler: async def get_3pe_protocols(self, only_protocol=None): services = self.store.get_app_services() - protocols = {} + protocols = {} # type: Dict[str, List[JsonDict]] # Collect up all the individual protocol responses out of the ASes for s in services: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0c6aec347e..7f00805a91 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -437,9 +437,9 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, - token_id: Optional[str] = None, txn_id: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -453,13 +453,18 @@ class EventCreationHandler: Args: requester event_dict: An entire event - token_id txn_id prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. + + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + require_consent: Whether to check if the requester has consented to the privacy policy. 
Raises: @@ -511,14 +516,17 @@ class EventCreationHandler: if require_consent and not is_exempt: await self.assert_accepted_privacy_policy(requester) - if token_id is not None: - builder.internal_metadata.token_id = token_id + if requester.access_token_id is not None: + builder.internal_metadata.token_id = requester.access_token_id if txn_id is not None: builder.internal_metadata.txn_id = txn_id event, context = await self.create_new_client_event( - builder=builder, requester=requester, prev_event_ids=prev_event_ids, + builder=builder, + requester=requester, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -726,7 +734,7 @@ class EventCreationHandler: return event, event.internal_metadata.stream_ordering event, context = await self.create_event( - requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id + requester, event_dict, txn_id=txn_id ) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( @@ -757,6 +765,7 @@ class EventCreationHandler: builder: EventBuilder, requester: Optional[Requester] = None, prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -769,6 +778,11 @@ class EventCreationHandler: If None, they will be requested from the database. + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. 
+ Returns: Tuple of created event, context """ @@ -790,7 +804,9 @@ class EventCreationHandler: builder.type == EventTypes.Create or len(prev_event_ids) > 0 ), "Attempting to create an event with no prev_events" - event = await builder.build(prev_event_ids=prev_event_ids) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids + ) context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 7225923757..c242c409cf 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple +from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler -from synapse.types import ReadReceipt, get_domain_from_id +from synapse.types import JsonDict, ReadReceipt, get_domain_from_id from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -140,5 +142,36 @@ class ReceiptEventSource: return (events, to_key) + async def get_new_events_as( + self, from_key: int, service: ApplicationService + ) -> Tuple[List[JsonDict], int]: + """Returns a set of new receipt events that an appservice + may be interested in. 
+ + Args: + from_key: the stream position at which events should be fetched from + service: The appservice which may be interested + """ + from_key = int(from_key) + to_key = self.get_current_key() + + if from_key == to_key: + return [], to_key + + # We first need to fetch all new receipts + rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms( + from_key=from_key, to_key=to_key + ) + + # Then filter down to rooms that the AS can read + events = [] + for room_id, event in rooms_to_events.items(): + if not await service.matches_user_in_member_list(room_id, self.store): + continue + + events.append(event) + + return (events, to_key) + def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 93ed51063a..ec300d8877 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler): "replacement_room": new_room_id, }, }, - token_id=requester.access_token_id, ) old_room_version = await self.store.get_room_version_id(old_room_id) await self.auth.check_from_context( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0080eeaf8d..ec784030e9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,12 +17,10 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union - -from unpaddedbase64 import encode_base64 +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -31,12 +29,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from 
synapse.api.room_versions import EventFormatVersions -from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase -from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext -from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer @@ -193,7 +187,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # For backwards compatibility: "membership": membership, }, - token_id=requester.access_token_id, txn_id=txn_id, prev_event_ids=prev_event_ids, require_consent=require_consent, @@ -1133,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id = invite_event.room_id target_user = invite_event.state_key - room_version = await self.store.get_room_version(room_id) content["membership"] = Membership.LEAVE - # the auth events for the new event are the same as that of the invite, plus - # the invite itself. - # - # the prev_events are just the invite. 
- invite_hash = invite_event.event_id # type: Union[str, Tuple] - if room_version.event_format == EventFormatVersions.V1: - alg, h = compute_event_reference_hash(invite_event) - invite_hash = (invite_event.event_id, {alg: encode_base64(h)}) - - auth_events = tuple(invite_event.auth_events) + (invite_hash,) - prev_events = (invite_hash,) - - # we cap depth of generated events, to ensure that they are not - # rejected by other servers (and so that they can be persisted in - # the db) - depth = min(invite_event.depth + 1, MAX_DEPTH) - event_dict = { - "depth": depth, - "auth_events": auth_events, - "prev_events": prev_events, "type": EventTypes.Member, "room_id": room_id, "sender": target_user, @@ -1165,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler): "state_key": target_user, } - event = create_local_event_from_event_dict( - clock=self.clock, - hostname=self.hs.hostname, - signing_key=self.hs.signing_key, - room_version=room_version, - event_dict=event_dict, + # the auth events for the new event are the same as that of the invite, plus + # the invite itself. + # + # the prev_events are just the invite. 
+ prev_event_ids = [invite_event.event_id] + auth_event_ids = invite_event.auth_event_ids() + prev_event_ids + + event, context = await self.event_creation_handler.create_event( + requester, + event_dict, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True - if txn_id is not None: - event.internal_metadata.txn_id = txn_id - if requester.access_token_id is not None: - event.internal_metadata.token_id = requester.access_token_id - - EventValidator().validate_new(event, self.config) - context = await self.state_handler.compute_event_context(event) - context.app_service = requester.app_service result_event = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[UserID.from_string(target_user)], ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a306631094..b527724bc4 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3cbfc2d780..d3692842e3 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import logging import random from collections import namedtuple from typing import TYPE_CHECKING, List, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError +from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -430,6 +430,33 @@ class TypingNotificationEventSource: "content": {"user_ids": list(typing)}, } + async def get_new_events_as( + self, from_key: int, service: ApplicationService + ) -> Tuple[List[JsonDict], int]: + """Returns a set of new typing events that an appservice + may be interested in. + + Args: + from_key: the stream position at which events should be fetched from + service: The appservice which may be interested + """ + with Measure(self.clock, "typing.get_new_events_as"): + from_key = int(from_key) + handler = self.get_typing_handler() + + events = [] + for room_id in handler._room_serials.keys(): + if handler._room_serials[room_id] <= from_key: + continue + if not await service.matches_user_in_member_list( + room_id, handler.store + ): + continue + + events.append(self._make_event_for(room_id)) + + return (events, handler._latest_room_serial) + async def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) diff --git a/synapse/notifier.py b/synapse/notifier.py index 13adeed01e..2e993411b9 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -319,19 +319,35 @@ class Notifier: ) if self.federation_sender: - self.federation_sender.notify_new_events(max_room_stream_token.stream) + 
self.federation_sender.notify_new_events(max_room_stream_token) async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: await self.appservice_handler.notify_interested_services( - max_room_stream_token.stream + max_room_stream_token + ) + except Exception: + logger.exception("Error notifying application services of event") + + async def _notify_app_services_ephemeral( + self, + stream_key: str, + new_token: Union[int, RoomStreamToken], + users: Collection[UserID] = [], + ): + try: + stream_token = None + if isinstance(new_token, int): + stream_token = new_token + await self.appservice_handler.notify_interested_services_ephemeral( + stream_key, stream_token, users ) except Exception: logger.exception("Error notifying application services of event") async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_token.stream) + await self._pusher_pool.on_new_notifications(max_room_stream_token) except Exception: logger.exception("Error pusher pool of event") @@ -367,6 +383,15 @@ class Notifier: self.notify_replication() + # Notify appservices + run_as_background_process( + "_notify_app_services_ephemeral", + self._notify_app_services_ephemeral, + stream_key, + new_token, + users, + ) + def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 28bd8ab748..c6763971ee 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -18,6 +18,7 @@ import logging from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) @@ -91,7 +92,12 @@ class EmailPusher: pass self.timed_call = None - def 
on_new_notifications(self, max_stream_ordering): + def on_new_notifications(self, max_token: RoomStreamToken): + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_ordering = max_token.stream + if self.max_stream_ordering: self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 26706bf3e1..793d0db2d9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException +from synapse.types import RoomStreamToken from . import push_rule_evaluator, push_tools @@ -114,7 +115,12 @@ class HttpPusher: if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, max_stream_ordering): + def on_new_notifications(self, max_token: RoomStreamToken): + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. 
+ max_stream_ordering = max_token.stream + self.max_stream_ordering = max( max_stream_ordering, self.max_stream_ordering or 0 ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 76150e117b..0080c68ce2 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -24,6 +24,7 @@ from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher from synapse.push.httppusher import HttpPusher from synapse.push.pusher import PusherFactory +from synapse.types import RoomStreamToken from synapse.util.async_helpers import concurrently_execute if TYPE_CHECKING: @@ -186,11 +187,16 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, max_stream_id: int): + async def on_new_notifications(self, max_token: RoomStreamToken): if not self.pushers: # nothing to do here. return + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. 
+ max_stream_id = max_token.stream + if max_stream_id < self._last_room_stream_id_seen: # Nothing to do return @@ -214,7 +220,7 @@ class PusherPool: if u in self.pushers: for p in self.pushers[u].values(): - p.on_new_notifications(max_stream_id) + p.on_new_notifications(max_token) except Exception: logger.exception("Exception in pusher on_new_notifications") diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1f8dafe7ea..4b0ea0cc01 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.deferred_cache import DeferredCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) + ) # type: DeferredCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index b686cd671f..e7fcd2b1ff 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -59,7 +59,9 @@ class ProfileDisplaynameRestServlet(RestServlet): try: new_name = content["displayname"] except Exception: - return 400, "Unable to parse name" + raise SynapseError( + code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON, + ) await self.profile_handler.set_displayname(user, requester, new_name, is_admin) diff --git a/synapse/server.py b/synapse/server.py index f921ee4b53..21a232bbd9 100644 --- 
a/synapse/server.py +++ b/synapse/server.py @@ -205,7 +205,13 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): + def __init__( + self, + hostname: str, + config: HomeServerConfig, + reactor=None, + version_string="Synapse", + ): """ Args: hostname : The hostname for the server. @@ -236,11 +242,9 @@ class HomeServer(metaclass=abc.ABCMeta): burst_count=config.rc_registration.burst_count, ) - self.datastores = None # type: Optional[Databases] + self.version_string = version_string - # Other kwargs are explicit dependencies - for depname in kwargs: - setattr(self, depname, kwargs[depname]) + self.datastores = None # type: Optional[Databases] def get_instance_id(self) -> str: """A unique ID for this synapse process instance. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0ba3a025cf..763722d6bc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -893,6 +893,12 @@ class DatabasePool: attempts = 0 while True: try: + # We can autocommit if we are going to use native upserts + autocommit = ( + self.engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ) + return await self.runInteraction( desc, self.simple_upsert_txn, @@ -901,6 +907,7 @@ class DatabasePool: values, insertion_values, lock=lock, + db_autocommit=autocommit, ) except self.engine.module.IntegrityError as e: attempts += 1 @@ -1063,6 +1070,43 @@ class DatabasePool: ) txn.execute(sql, list(allvalues.values())) + async def simple_upsert_many( + self, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + desc: str, + ) -> None: + """ + Upsert, many times. + + Args: + table: The table to upsert into + key_names: The key column names. 
+ key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. + """ + + # We can autocommit if we are going to use native upserts + autocommit = ( + self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables + ) + + return await self.runInteraction( + desc, + self.simple_upsert_many_txn, + table, + key_names, + key_values, + value_names, + value_values, + db_autocommit=autocommit, + ) + def simple_upsert_many_txn( self, txn: LoggingTransaction, @@ -1214,7 +1258,13 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ return await self.runInteraction( - desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, + self.simple_select_one_txn, + table, + keyvalues, + retcols, + allow_none, + db_autocommit=True, ) @overload @@ -1265,6 +1315,7 @@ class DatabasePool: keyvalues, retcol, allow_none=allow_none, + db_autocommit=True, ) @overload @@ -1346,7 +1397,12 @@ class DatabasePool: Results in a list """ return await self.runInteraction( - desc, self.simple_select_onecol_txn, table, keyvalues, retcol + desc, + self.simple_select_onecol_txn, + table, + keyvalues, + retcol, + db_autocommit=True, ) async def simple_select_list( @@ -1371,7 +1427,12 @@ class DatabasePool: A list of dictionaries. 
""" return await self.runInteraction( - desc, self.simple_select_list_txn, table, keyvalues, retcols + desc, + self.simple_select_list_txn, + table, + keyvalues, + retcols, + db_autocommit=True, ) @classmethod @@ -1450,6 +1511,7 @@ class DatabasePool: chunk, keyvalues, retcols, + db_autocommit=True, ) results.extend(rows) @@ -1548,7 +1610,12 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ await self.runInteraction( - desc, self.simple_update_one_txn, table, keyvalues, updatevalues + desc, + self.simple_update_one_txn, + table, + keyvalues, + updatevalues, + db_autocommit=True, ) @classmethod @@ -1607,7 +1674,9 @@ class DatabasePool: keyvalues: dict of column names and values to select the row with desc: description of the transaction, for logging and metrics """ - await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + await self.runInteraction( + desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True, + ) @staticmethod def simple_delete_one_txn( @@ -1646,7 +1715,9 @@ class DatabasePool: Returns: The number of deleted rows. 
""" - return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + return await self.runInteraction( + desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True + ) @staticmethod def simple_delete_txn( @@ -1694,7 +1765,13 @@ class DatabasePool: Number rows deleted """ return await self.runInteraction( - desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + desc, + self.simple_delete_many_txn, + table, + column, + iterable, + keyvalues, + db_autocommit=True, ) @staticmethod @@ -1860,7 +1937,13 @@ class DatabasePool: """ return await self.runInteraction( - desc, self.simple_search_list_txn, table, term, col, retcols + desc, + self.simple_search_list_txn, + table, + term, + col, + retcols, + db_autocommit=True, ) @classmethod diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 85f6b1e3fd..43bf0f649a 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,12 +15,15 @@ # limitations under the License. 
import logging import re +from typing import List -from synapse.appservice import AppServiceTransaction +from synapse.appservice import ApplicationService, AppServiceTransaction from synapse.config.appservice import load_appservices +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.types import JsonDict from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -172,15 +175,23 @@ class ApplicationServiceTransactionWorkerStore( "application_services_state", {"as_id": service.id}, {"state": state} ) - async def create_appservice_txn(self, service, events): + async def create_appservice_txn( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict], + ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service - with the given list of events. + with the given list of events. Ephemeral events are NOT persisted to the + database and are not resent if a transaction is retried. Args: - service(ApplicationService): The service who the transaction is for. - events(list<Event>): A list of events to put in the transaction. + service: The service who the transaction is for. + events: A list of persistent events to put in the transaction. + ephemeral: A list of ephemeral events to put in the transaction. + Returns: - AppServiceTransaction: A new transaction. + A new transaction. 
""" def _create_appservice_txn(txn): @@ -207,7 +218,9 @@ class ApplicationServiceTransactionWorkerStore( "VALUES(?,?,?)", (service.id, new_txn_id, event_ids), ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) + return AppServiceTransaction( + service=service, id=new_txn_id, events=events, ephemeral=ephemeral + ) return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn @@ -296,7 +309,9 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + return AppServiceTransaction( + service=service, id=entry["txn_id"], events=events, ephemeral=[] + ) def _get_last_txn(self, txn, service_id): txn.execute( @@ -320,7 +335,7 @@ class ApplicationServiceTransactionWorkerStore( ) async def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" + """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): sql = ( @@ -351,6 +366,39 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, events + async def get_type_stream_id_for_appservice( + self, service: ApplicationService, type: str + ) -> int: + def get_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "SELECT ? FROM application_services_state WHERE as_id=?", + (stream_id_type, service.id,), + ) + last_txn_id = txn.fetchone() + if last_txn_id is None or last_txn_id[0] is None: # no row exists + return 0 + else: + return int(last_txn_id[0]) + + return await self.db_pool.runInteraction( + "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn + ) + + async def set_type_stream_id_for_appservice( + self, service: ApplicationService, type: str, pos: int + ) -> None: + def set_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "UPDATE ? SET device_list_stream_id = ? 
WHERE as_id=?", + (stream_id_type, pos, service.id), + ) + + await self.db_pool.runInteraction( + "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + ) + class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): # This is currently empty due to there not being any AS storage functions diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index a25a888443..9e66e6648a 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) @@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88fd97e1df..e662a20d24 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter from 
synapse.util.stringutils import shortstr @@ -1004,7 +1005,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = Cache( + self.device_id_exists_cache = DeferredCache( name="device_id_exists", keylen=2, max_entries=10000 ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fdb17745f6..ba3b1769b0 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1270,6 +1270,10 @@ class PersistEventsStore: ) def _store_retention_policy_for_room_txn(self, txn, event): + if not event.is_state(): + logger.debug("Ignoring non-state m.room.retention event") + return + if hasattr(event, "content") and ( "min_lifetime" in event.content or "max_lifetime" in event.content ): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3ec4d1d9c2..ff150f0be7 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -145,7 +146,7 @@ class EventsWorkerStore(SQLBaseStore): self._cleanup_old_transaction_ids, ) - self._get_event_cache = Cache( + self._get_event_cache = DeferredCache( "*getEvent*", keylen=3, max_entries=hs.config.caches.event_cache_size, diff --git a/synapse/storage/databases/main/keys.py 
b/synapse/storage/databases/main/keys.py index ad43bb05ab..f8f4bb9b3f 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - await self.db_pool.runInteraction( - "store_server_verify_keys", - self.db_pool.simple_upsert_many_txn, + await self.db_pool.simple_upsert_many( table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, + desc="store_server_verify_keys", ) invalidate = self._get_server_verify_key.invalidate diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 0acf0617ca..79b01d16f9 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -281,9 +281,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): a_day_in_milliseconds = 24 * 60 * 60 * 1000 now = self._clock.time_msec() + # A note on user_agent. Technically a given device can have multiple + # user agents, so we need to decide which one to pick. We could have handled this + # in a number of ways, but given that we don't care _that_ much we have gone for MAX(). + # For more details of the other options considered see + # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111 sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? + INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent) + SELECT u.user_id, u.device_id, ?, MAX(u.user_agent) FROM user_ips AS u LEFT JOIN ( SELECT user_id, device_id, timestamp FROM user_daily_visits @@ -294,7 +299,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): WHERE last_seen > ? AND last_seen <= ?
AND udv.timestamp IS NULL AND users.is_guest=0 AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id + GROUP BY u.user_id, u.device_id, u.user_agent """ # This means that the day has rolled over but there could still diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c79ddff680..5cdf16521c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList @@ -274,6 +275,60 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): } return results + @cached(num_args=2,) + async def get_linearized_receipts_for_all_rooms( + self, to_key: int, from_key: Optional[int] = None + ) -> Dict[str, JsonDict]: + """Get receipts for all rooms between two stream_ids. + + Args: + to_key: Max stream id to fetch receipts up to. + from_key: Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + A dictionary mapping each room ID to one m.receipt event batching its receipts. + """ + + def f(txn): + if from_key: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, [from_key, to_key]) + else: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ?
+ """ + + txn.execute(sql, [to_key]) + + return self.db_pool.cursor_to_dict(txn) + + txn_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_all_rooms", f + ) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = db_to_json(row["data"]) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql new file mode 100644 index 0000000000..b0b5dcddce --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + -- Add new column to user_daily_visits to track user agent +ALTER TABLE user_daily_visits + ADD COLUMN user_agent TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql b/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql new file mode 100644 index 0000000000..20f5a95a24 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +ALTER TABLE application_services_state + ADD COLUMN read_receipt_stream_id INT, + ADD COLUMN presence_stream_id INT; \ No newline at end of file diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 7d46090267..59207cadd4 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -208,42 +208,56 @@ class TransactionStore(TransactionWorkerStore): """ self._destination_retry_cache.pop(destination, None) - return await self.db_pool.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) + if self.database_engine.can_native_upsert: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_native, + destination, + failure_ts, + retry_last_ts, + retry_interval, + db_autocommit=True, # Safe as it's a single upsert + ) + else: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_emulated, + destination, + failure_ts, + retry_last_ts, + retry_interval, + ) - def _set_destination_retry_timings( + def _set_destination_retry_timings_native( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): + assert self.database_engine.can_native_upsert + + # Upsert retry time interval if retry_interval is zero (i.e. we're + # resetting it) or greater than the existing retry interval. + # + # WARNING: This is executed in autocommit, so we shouldn't add any more + # SQL calls in here (without being very careful). + sql = """ + INSERT INTO destinations ( + destination, failure_ts, retry_last_ts, retry_interval + ) + VALUES (?, ?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET + failure_ts = EXCLUDED.failure_ts, + retry_last_ts = EXCLUDED.retry_last_ts, + retry_interval = EXCLUDED.retry_interval + WHERE + EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL + OR destinations.retry_interval < EXCLUDED.retry_interval + """ - if self.database_engine.can_native_upsert: - # Upsert retry time interval if retry_interval is zero (i.e. we're - # resetting it) or greater than the existing retry interval. - - sql = """ - INSERT INTO destinations ( - destination, failure_ts, retry_last_ts, retry_interval - ) - VALUES (?, ?, ?, ?) - ON CONFLICT (destination) DO UPDATE SET - failure_ts = EXCLUDED.failure_ts, - retry_last_ts = EXCLUDED.retry_last_ts, - retry_interval = EXCLUDED.retry_interval - WHERE - EXCLUDED.retry_interval = 0 - OR destinations.retry_interval IS NULL - OR destinations.retry_interval < EXCLUDED.retry_interval - """ - - txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) - - return + txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + def _set_destination_retry_timings_emulated( + self, txn, destination, failure_ts, retry_last_ts, retry_interval + ): self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 5a390ff2f6..d87ceec6da 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id_tuples: iterable of 2-tuple of user IDs. 
""" - def _add_users_who_share_room_txn(txn): - self.db_pool.simple_upsert_many_txn( - txn, - table="users_who_share_private_rooms", - key_names=["user_id", "other_user_id", "room_id"], - key_values=[ - (user_id, other_user_id, room_id) - for user_id, other_user_id in user_id_tuples - ], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_who_share_room", _add_users_who_share_room_txn + await self.db_pool.simple_upsert_many( + table="users_who_share_private_rooms", + key_names=["user_id", "other_user_id", "room_id"], + key_values=[ + (user_id, other_user_id, room_id) + for user_id, other_user_id in user_id_tuples + ], + value_names=(), + value_values=None, + desc="add_users_who_share_room", ) async def add_users_in_public_rooms( @@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_ids """ - def _add_users_in_public_rooms_txn(txn): - - self.db_pool.simple_upsert_many_txn( - txn, - table="users_in_public_rooms", - key_names=["user_id", "room_id"], - key_values=[(user_id, room_id) for user_id in user_ids], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_in_public_rooms", _add_users_in_public_rooms_txn + await self.db_pool.simple_upsert_many( + table="users_in_public_rooms", + key_names=["user_id", "room_id"], + key_values=[(user_id, room_id) for user_id in user_ids], + value_names=(), + value_values=None, + desc="add_users_in_public_rooms", ) async def delete_all_from_user_dir(self) -> None: diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 3d8da48f2d..02d71302ea 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -618,14 +618,7 @@ class _MultiWriterCtxManager: db_autocommit=True, ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. 
with self.id_gen._lock: - assert max(self.id_gen._current_positions.values(), default=0) < min( - self.stream_ids - ) - self.id_gen._unfinished_ids.update(self.stream_ids) if self.multiple_ids is None: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py new file mode 100644 index 0000000000..f728cd2cf2 --- /dev/null +++ b/synapse/util/caches/deferred_cache.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import threading +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast + +from prometheus_client import Gauge + +from twisted.internet import defer + +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) + + +KT = TypeVar("KT") +VT = TypeVar("VT") + + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. 
+ sentinel = object() + + +class DeferredCache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + + __slots__ = ( + "cache", + "name", + "keylen", + "thread", + "metrics", + "_pending_deferred_cache", + ) + + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + """ + cache_type = TreeCache if tree else dict + + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. 
+ self.cache = LruCache( + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, + size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, + ) + + self.name = name + self.keylen = keylen + self.thread = None # type: Optional[threading.Thread] + self.metrics = register_cache( + "cache", + name, + self.cache, + collect_callback=self._metrics_collection_callback, + ) + + @property + def max_entries(self): + return self.cache.max_size + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + + def _metrics_collection_callback(self): + cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. 
If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either an ObservableDeferred or the result itself + """ + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: + self.metrics.inc_hits() + return val + + if update_metrics: + self.metrics.inc_misses() + + if default is _Sentinel.sentinel: + raise KeyError() + else: + return default + + def set( + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, + ) -> ObservableDeferred: + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + + callbacks = [callback] if callback else [] + self.check_thread() + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. 
(We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): + self.cache.set(key, result, entry.callbacks) + else: + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. + entry.invalidate() + + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable + + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) + + def invalidate(self, key): + self.check_thread() + self.cache.pop(key, None) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. + entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. 
+ if entry: + entry.invalidate() + + def invalidate_many(self, key: KT): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + self.cache.del_multi(key) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + + def invalidate_all(self): + self.check_thread() + self.cache.clear() + for entry in self._pending_deferred_cache.values(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class CacheEntry: + __slots__ = ["deferred", "callbacks", "invalidated"] + + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): + self.deferred = deferred + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 98b34f2223..1f43886804 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,25 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import functools import inspect import logging -import threading from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary -from prometheus_client import Gauge - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -from . import register_cache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) @@ -55,239 +48,6 @@ class _CachedFunction(Generic[F]): __call__ = None # type: F -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - -_CacheSentinel = object() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__(self, deferred, callbacks): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self): - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() - - -class Cache: - __slots__ = ( - "cache", - "name", - "keylen", - "thread", - "metrics", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - keylen: int = 1, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: 
Whether cache factors specified in the - config file affect `max_entries` - - Returns: - Cache - """ - cache_type = TreeCache if tree else dict - self._pending_deferred_cache = cache_type() - - self.cache = LruCache( - max_size=max_entries, - keylen=keylen, - cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, - apply_cache_factor_from_config=apply_cache_factor_from_config, - ) - - self.name = name - self.keylen = keylen - self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) - - @property - def max_entries(self): - return self.cache.max_size - - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): - """Looks the key up in the caches. - - Args: - key(tuple) - default: What is returned if key is not in the caches. 
If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - Either an ObservableDeferred or the raw result - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) - if val is not _CacheSentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _CacheSentinel: - raise KeyError() - else: - return default - - def set(self, key, value, callback=None): - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - self._pending_deferred_cache[key] = entry - - def compare_and_pop(): - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. 
(We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result): - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail): - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - return observable - - def prefill(self, key, value, callback=None): - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key): - self.check_thread() - self.cache.pop(key, None) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. 
- if entry: - entry.invalidate() - - def invalidate_many(self, key): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): - entry.invalidate() - - def invalidate_all(self): - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - class _CacheDescriptorBase: def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig @@ -390,13 +150,13 @@ class CacheDescriptor(_CacheDescriptorBase): self.iterable = iterable def __get__(self, obj, owner): - cache = Cache( + cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) + ) # type: DeferredCache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -640,9 +400,9 @@ class _CacheContext: _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None self._cache = cache self._cache_key = cache_key @@ -651,7 +411,9 @@ class _CacheContext: self._cache.invalidate(self._cache_key) @classmethod - def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + def get_instance( + cls, cache, cache_key + ): # type: (DeferredCache, CacheKey) -> _CacheContext """Returns an instance constructed with the 
given arguments. A new instance is only created if none already exists. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4bc1a67b58..33eae2b7c4 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -64,7 +64,8 @@ class LruCache: Args: max_size: The maximum amount of entries the cache can hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + cache_type is `TreeCache`. cache_type (type): type of underlying cache to be used. Typically one of dict diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 3e180cafd3..6ce2a3d12b 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -34,7 +34,7 @@ class TTLCache: self._data = {} # the _CacheEntries, sorted by expiry time - self._expiry_list = SortedList() + self._expiry_list = SortedList() # type: SortedList[_CacheEntry] self._timer = timer diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 641093d349..4a301b84e1 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + http_client=None, homeserver_to_use=GenericWorkerServer ) return hs diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 0f016c32eb..c2b10d2c70 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -26,7 +26,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ 
-84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer + http_client=None, homeserver_to_use=SynapseHomeServer ) return hs diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 68a4caabbf..2acb8b7603 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events # txn made and saved + service=service, events=events, ephemeral=[] # txn made and saved ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events # txn made and saved + service=service, events=events, ephemeral=[] # txn made and saved ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events + service=service, events=events, ephemeral=[] ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -202,26 +202,28 @@ class 
ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): # Expect the event to be sent immediately. service = Mock(id=4) event = Mock() - self.queuer.enqueue(service, event) - self.txn_ctrl.send.assert_called_once_with(service, [event]) + self.queuer.enqueue_event(service, event) + self.txn_ctrl.send.assert_called_once_with(service, [event], []) def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d)) + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") event3 = Mock(event_id="third") # Send an event and don't resolve it just yet. - self.queuer.enqueue(service, event) + self.queuer.enqueue_event(service, event) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue(service, event2) - self.queuer.enqueue(service, event3) - self.txn_ctrl.send.assert_called_with(service, [event]) + self.queuer.enqueue_event(service, event2) + self.queuer.enqueue_event(service, event3) + self.txn_ctrl.send.assert_called_with(service, [event], []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3]) + self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -239,21 +241,58 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y): + def do_send(x, y, z): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent - self.queuer.enqueue(srv1, srv_1_event) - self.queuer.enqueue(srv1, srv_1_event2) - 
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event]) - self.queuer.enqueue(srv2, srv_2_event) - self.queuer.enqueue(srv2, srv_2_event2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event]) + self.queuer.enqueue_event(srv1, srv_1_event) + self.queuer.enqueue_event(srv1, srv_1_event2) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) + self.queuer.enqueue_event(srv2, srv_2_event) + self.queuer.enqueue_event(srv2, srv_2_event2) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2]) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) self.assertEquals(3, self.txn_ctrl.send.call_count) + + def test_send_single_ephemeral_no_queue(self): + # Expect the event to be sent immediately. + service = Mock(id=4, name="service") + event_list = [Mock(name="event")] + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + + def test_send_multiple_ephemeral_no_queue(self): + # Expect the event to be sent immediately. + service = Mock(id=4, name="service") + event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + + def test_send_single_ephemeral_with_queue(self): + d = defer.Deferred() + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) + service = Mock(id=4) + event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] + event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] + event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] + + # Send an event and don't resolve it just yet. 
+ self.queuer.enqueue_ephemeral(service, event_list_1) + # Send more events: expect send() to NOT be called multiple times. + self.queuer.enqueue_ephemeral(service, event_list_2) + self.queuer.enqueue_ephemeral(service, event_list_3) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1) + self.assertEquals(1, self.txn_ctrl.send.call_count) + # Resolve txn_ctrl.send + d.callback(service) + # Expect the queued events to be sent + self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) + self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 2a0b7c1b56..ee4f3da31c 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -18,6 +18,7 @@ from mock import Mock from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.types import RoomStreamToken from tests.test_utils import make_awaitable from tests.utils import MockClock @@ -61,7 +62,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred(self.handler.notify_interested_services(0)) + yield defer.ensureDeferred( + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + ) self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) @@ -80,7 +83,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred(self.handler.notify_interested_services(0)) + yield defer.ensureDeferred( + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + ) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @defer.inlineCallbacks @@ -97,7 +102,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield 
defer.ensureDeferred(self.handler.notify_interested_services(0)) + yield defer.ensureDeferred( + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + ) self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 64e28bc639..9f6f21a6e2 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase): "sender": self.requester.user.to_string(), "content": {"msgtype": "m.text", "body": random_string(5)}, }, - token_id=self.token_id, txn_id=txn_id, ) ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 914c82e7a8..8ed67640f8 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.store.get_latest_event_ids_in_room(room_id) ) - event = self.get_success(builder.build(prev_event_ids)) + event = self.get_success(builder.build(prev_event_ids, None)) self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 81ea985b9f..093e2faac7 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -59,7 +59,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( http_client=None, - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config.update(extra_config) worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, **kwargs diff --git 
a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 23be1167a3..1853667558 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return hs diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 9c4a9c3563..779745ae9d 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): } builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success(builder.build(prev_event_ids)) + join_event = self.get_success(builder.build(prev_event_ids, None)) self.get_success(federation.on_send_join_request(remote_server, join_event)) self.replicate() diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 6068d14905..82cf033d4e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -14,8 +14,12 @@ # limitations under the License. 
import logging +from mock import patch + +from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.utils import USE_POSTGRES_FOR_TESTS @@ -36,6 +40,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, + sync.register_servlets, ] def prepare(self, reactor, clock, hs): @@ -43,6 +48,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.other_user_id = self.register_user("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass") + self.room_creator = self.hs.get_room_creation_handler() + self.store = hs.get_datastore() + def default_config(self): conf = super().default_config() conf["redis"] = {"enabled": "true"} @@ -53,6 +61,29 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): } return conf + def _create_room(self, room_id: str, user_id: str, tok: str): + """Create a room with given room_id + """ + + # We control the room ID generation by patching out the + # `_generate_room_id` method + async def generate_room( + creator_id: str, is_public: bool, room_version: RoomVersion + ): + await self.store.store_room( + room_id=room_id, + room_creator_user_id=creator_id, + is_public=is_public, + room_version=room_version, + ) + return room_id + + with patch( + "synapse.handlers.room.RoomCreationHandler._generate_room_id" + ) as mock: + mock.side_effect = generate_room + self.helper.create_room_as(user_id, tok=tok) + def test_basic(self): """Simple test to ensure that multiple rooms can be created and joined, and that different rooms get handled by different instances. 
@@ -100,3 +131,189 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.assertTrue(persisted_on_1) self.assertTrue(persisted_on_2) + + def test_vector_clock_token(self): + """Tests that using a stream token with a vector clock component works + correctly with basic /sync and /messages usage. + """ + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker1"}, + ) + + worker_hs2 = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker2"}, + ) + + sync_hs = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "sync"}, + ) + + # Specially selected room IDs that get persisted on different workers. + room_id1 = "!foo:test" + room_id2 = "!baz:test" + + self.assertEqual( + self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1" + ) + self.assertEqual( + self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2" + ) + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + store = self.hs.get_datastore() + + # Create two rooms on different workers. + self._create_room(room_id1, user_id, access_token) + self._create_room(room_id2, user_id, access_token) + + # The other user joins + self.helper.join( + room=room_id1, user=self.other_user_id, tok=self.other_access_token + ) + self.helper.join( + room=room_id2, user=self.other_user_id, tok=self.other_access_token + ) + + # Do an initial sync so that we're up to date. + request, channel = self.make_request("GET", "/sync", access_token=access_token) + self.render_on_worker(sync_hs, request) + next_batch = channel.json_body["next_batch"] + + # We now gut wrench into the events stream MultiWriterIdGenerator on + # worker2 to mimic it getting stuck persisting an event.
This ensures + # that when we send an event on worker1 we end up in a state where + # worker2's events stream position lags that on worker1, resulting in a + # RoomStreamToken with a non-empty instance map component. + # + # Worker2's event stream position will not advance until we call + # __aexit__ again. + actx = worker_hs2.get_datastore()._stream_id_gen.get_next() + self.get_success(actx.__aenter__()) + + response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) + first_event_in_room1 = response["event_id"] + + # Assert that the current stream token has an instance map component, as + # we are trying to test vector clock tokens. + room_stream_token = store.get_room_max_token() + self.assertNotEqual(len(room_stream_token.instance_map), 0) + + # Check that syncing still gets the new event, despite the gap in the + # stream IDs. + request, channel = self.make_request( + "GET", "/sync?since={}".format(next_batch), access_token=access_token + ) + self.render_on_worker(sync_hs, request) + + # We should only see the new event and nothing else + self.assertIn(room_id1, channel.json_body["rooms"]["join"]) + self.assertNotIn(room_id2, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"] + self.assertListEqual( + [first_event_in_room1], [event["event_id"] for event in events] + ) + + # Get the next batch and make sure it's a vector clock style token. + vector_clock_token = channel.json_body["next_batch"] + self.assertTrue(vector_clock_token.startswith("m")) + + # Now that we've got a vector clock token we finish the fake persisting + # an event we started above. + self.get_success(actx.__aexit__(None, None, None)) + + # Now try and send an event to the other room so that we can test that + # the vector clock style token works as a `since` token.
+ response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) + first_event_in_room2 = response["event_id"] + + request, channel = self.make_request( + "GET", + "/sync?since={}".format(vector_clock_token), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + + self.assertNotIn(room_id1, channel.json_body["rooms"]["join"]) + self.assertIn(room_id2, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"] + self.assertListEqual( + [first_event_in_room2], [event["event_id"] for event in events] + ) + + next_batch = channel.json_body["next_batch"] + + # We also want to test that the vector clock style token works with + # pagination. We do this by sending a couple of new events into the room + # and syncing again to get a prev_batch token for each room, then + # paginating from there back to the vector clock token. + self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) + self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) + + request, channel = self.make_request( + "GET", "/sync?since={}".format(next_batch), access_token=access_token + ) + self.render_on_worker(sync_hs, request) + + prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][ + "prev_batch" + ] + prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][ + "prev_batch" + ] + + # Paginating back in the first room should not produce any results, as + # no events have happened in it. This tests that we are correctly + # filtering results based on the vector clock portion. + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=b".format( + room_id1, prev_batch1, vector_clock_token + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertListEqual([], channel.json_body["chunk"]) + + # Paginating back on the second room should produce the first event + # again. 
This tests that pagination isn't completely broken. + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=b".format( + room_id2, prev_batch2, vector_clock_token + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertEqual( + channel.json_body["chunk"][0]["event_id"], first_event_in_room2 + ) + + # Paginating forwards should give the same results + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=f".format( + room_id1, vector_clock_token, prev_batch1 + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertListEqual([], channel.json_body["chunk"]) + + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=f".format( + room_id2, vector_clock_token, prev_batch2, + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertEqual( + channel.json_body["chunk"][0]["event_id"], first_event_in_room2 + ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index f5afed017c..8e69b1e9cc 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,82 +20,11 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import cached from tests import unittest -class CacheTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.cache = Cache("test") - - def test_empty(self): - failed = False - try: - self.cache.get("foo") - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_hit(self): - self.cache.prefill("foo", 123) - - self.assertEquals(self.cache.get("foo"), 123) - - def test_invalidate(self): - 
self.cache.prefill(("foo",), 123) - self.cache.invalidate(("foo",)) - - failed = False - try: - self.cache.get(("foo",)) - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_eviction(self): - cache = Cache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - cache.prefill(3, "three") # 1 will be evicted - - failed = False - try: - cache.get(1) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(2) - cache.get(3) - - def test_eviction_lru(self): - cache = Cache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - - # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) - - cache.prefill(3, "three") - - failed = False - try: - cache.get(2) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(1) - cache.get(3) - - class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c905a38930..c5c7987349 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) @@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -270,7 +270,7 @@ class 
ApplicationServiceTransactionStoreTestCase(unittest.TestCase): events = [Mock(event_id="e1"), Mock(event_id="e2")] yield self._set_last_txn(service.id, 9643) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(self.as_list[3]["id"], 9643, events) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 1ea35d60c1..d4f9e809db 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self._event_id = event_id @defer.inlineCallbacks - def build(self, prev_event_ids): + def build(self, prev_event_ids, auth_event_ids): built_event = yield defer.ensureDeferred( - self._base_builder.build(prev_event_ids) + self._base_builder.build(prev_event_ids, auth_event_ids) ) built_event._event_id = self._event_id diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f5f63d8ed6..759e4cd048 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest @@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase): Caches produce metrics reflecting their state when scraped. 
""" CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = Cache(CACHE_NAME, max_entries=777) + cache = DeferredCache(CACHE_NAME, max_entries=777) items = { x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") diff --git a/tests/unittest.py b/tests/unittest.py index 6c1661c92c..040b126a27 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Optional, Tuple, Type, TypeVar, Union +from typing import Optional, Tuple, Type, TypeVar, Union, overload from mock import Mock, patch @@ -364,6 +364,36 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ + # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest + # when the `request` arg isn't given, so we define an explicit override to + # cover that case. + @overload + def make_request( + self, + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, dict] = b"", + access_token: Optional[str] = None, + shorthand: bool = True, + federation_auth_origin: str = None, + content_is_form: bool = False, + ) -> Tuple[SynapseRequest, FakeChannel]: + ... + + @overload + def make_request( + self, + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, dict] = b"", + access_token: Optional[str] = None, + request: Type[T] = SynapseRequest, + shorthand: bool = True, + federation_auth_origin: str = None, + content_is_form: bool = False, + ) -> Tuple[T, FakeChannel]: + ... + def make_request( self, method: Union[bytes, str], diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py new file mode 100644 index 0000000000..9717be56b6 --- /dev/null +++ b/tests/util/caches/test_deferred_cache.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +from twisted.internet import defer + +from synapse.util.caches.deferred_cache import DeferredCache + + +class DeferredCacheTestCase(unittest.TestCase): + def test_empty(self): + cache = DeferredCache("test") + failed = False + try: + cache.get("foo") + except KeyError: + failed = True + + self.assertTrue(failed) + + def test_hit(self): + cache = DeferredCache("test") + cache.prefill("foo", 123) + + self.assertEquals(cache.get("foo"), 123) + + def test_invalidate(self): + cache = DeferredCache("test") + cache.prefill(("foo",), 123) + cache.invalidate(("foo",)) + + failed = False + try: + cache.get(("foo",)) + except KeyError: + failed = True + + self.assertTrue(failed) + + def test_invalidate_all(self): + cache = DeferredCache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return observable deferreds + self.assertFalse(cache.get("key1").has_called()) + self.assertFalse(cache.get("key2").has_called()) + + # let one of the lookups complete + d2.callback("result2") + + # for now at least, the cache will return real results rather than an + # observabledeferred + self.assertEqual(cache.get("key2"), 
"result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) + + def test_eviction(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + cache.prefill(3, "three") # 1 will be evicted + + failed = False + try: + cache.get(1) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(2) + cache.get(3) + + def test_eviction_lru(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + + # Now access 1 again, thus causing 2 to be least-recently used + cache.get(1) + + cache.prefill(3, "three") + + failed = False + try: + cache.get(2) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(1) + cache.get(3) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 677e925477..3d1f960869 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from functools import partial import mock @@ -42,49 +41,6 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class CacheTestCase(unittest.TestCase): - def test_invalidate_all(self): - cache = descriptors.Cache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) - - # let one of the lookups complete - d2.callback("result2") - - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") - - # now do the invalidation - cache.invalidate_all() - - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) - - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - - # letting the other lookup complete should do nothing - d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) - - class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): diff --git a/tests/utils.py b/tests/utils.py index 0c09f5457f..acec74e9e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,6 +21,7 @@ import time import uuid import warnings from inspect import getcallargs +from typing import Type from urllib import parse as urlparse from mock import Mock, patch @@ -194,8 +195,8 @@ def setup_test_homeserver( name="test", config=None, reactor=None, - homeserverToUse=TestHomeServer, - **kargs + homeserver_to_use: Type[HomeServer] = 
TestHomeServer, + **kwargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -218,8 +219,8 @@ def setup_test_homeserver( config.ldap_enabled = False - if "clock" not in kargs: - kargs["clock"] = MockClock() + if "clock" not in kwargs: + kwargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex @@ -264,18 +265,20 @@ def setup_test_homeserver( cur.close() db_conn.close() - hs = homeserverToUse( - name, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs + hs = homeserver_to_use( + name, config=config, version_string="Synapse/tests", reactor=reactor, ) + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": + if homeserver_to_use == TestHomeServer: hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): @@ -339,7 +342,7 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash - fed = kargs.get("resource_for_federation", None) + fed = kwargs.get("resource_for_federation", None) if fed: register_federation_servlets(hs, fed) |