diff --git a/CHANGES.md b/CHANGES.md
index 75dc5fa893..38a0814bbf 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,27 @@
+Synapse 1.21.2 (2020-10-15)
+===========================
+
+Debian packages and Docker images have been rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below.
+
+Security advisory
+-----------------
+
+* HTML pages served via Synapse were vulnerable to cross-site scripting (XSS)
+ attacks. All server administrators are encouraged to upgrade.
+ ([\#8444](https://github.com/matrix-org/synapse/pull/8444))
+ ([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891))
+
+ This fix was originally included in v1.21.0 but was missing a security advisory.
+
+ This was reported by [Denis Kasak](https://github.com/dkasak).
+
+Bugfixes
+--------
+
+- Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530))
+- An updated version of the authlib dependency is included in the Docker and Debian images to fix an issue using OpenID Connect. See [\#8534](https://github.com/matrix-org/synapse/issues/8534) for details.
+
+
Synapse 1.21.1 (2020-10-13)
===========================
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 524f82433d..c17e3b2399 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -63,6 +63,10 @@ run-time:
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
+You can also provide the `-d` option, which will lint the files that have been
+changed since the last git commit. This will often be significantly faster than
+linting the whole codebase.
+
Before pushing new changes, ensure they don't produce linting errors. Commit any
files that were corrected.
diff --git a/changelog.d/8437.feature b/changelog.d/8437.feature
new file mode 100644
index 0000000000..4abcccb326
--- /dev/null
+++ b/changelog.d/8437.feature
@@ -0,0 +1 @@
+Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.
diff --git a/changelog.d/8472.misc b/changelog.d/8472.misc
new file mode 100644
index 0000000000..880f3f5e14
--- /dev/null
+++ b/changelog.d/8472.misc
@@ -0,0 +1 @@
+Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit.
\ No newline at end of file
diff --git a/changelog.d/8488.misc b/changelog.d/8488.misc
new file mode 100644
index 0000000000..237cb3b311
--- /dev/null
+++ b/changelog.d/8488.misc
@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.
diff --git a/changelog.d/8503.misc b/changelog.d/8503.misc
new file mode 100644
index 0000000000..edb1be8aa8
--- /dev/null
+++ b/changelog.d/8503.misc
@@ -0,0 +1 @@
+Add user agent to user_daily_visits table.
diff --git a/changelog.d/8515.misc b/changelog.d/8515.misc
new file mode 100644
index 0000000000..1f8aa292d8
--- /dev/null
+++ b/changelog.d/8515.misc
@@ -0,0 +1 @@
+Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable.
diff --git a/changelog.d/8517.bugfix b/changelog.d/8517.bugfix
new file mode 100644
index 0000000000..1ab623c59f
--- /dev/null
+++ b/changelog.d/8517.bugfix
@@ -0,0 +1 @@
+Fix error code for `/profile/{userId}/displayname` to be `M_BAD_JSON`.
diff --git a/changelog.d/8526.doc b/changelog.d/8526.doc
new file mode 100644
index 0000000000..cbf48680c1
--- /dev/null
+++ b/changelog.d/8526.doc
@@ -0,0 +1 @@
+Added note about docker in manhole.md regarding which ip address to bind to. Contributed by @Maquis196.
diff --git a/changelog.d/8527.bugfix b/changelog.d/8527.bugfix
new file mode 100644
index 0000000000..727e0ba299
--- /dev/null
+++ b/changelog.d/8527.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.7.0 that could cause Synapse to insert values from non-state `m.room.retention` events into the `room_retention` database table.
diff --git a/changelog.d/8529.doc b/changelog.d/8529.doc
new file mode 100644
index 0000000000..6e710e6527
--- /dev/null
+++ b/changelog.d/8529.doc
@@ -0,0 +1 @@
+Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration.
diff --git a/changelog.d/8536.bugfix b/changelog.d/8536.bugfix
new file mode 100644
index 0000000000..8d238cc008
--- /dev/null
+++ b/changelog.d/8536.bugfix
@@ -0,0 +1 @@
+Fix not sending events over federation when using sharded event writers.
diff --git a/changelog.d/8537.misc b/changelog.d/8537.misc
new file mode 100644
index 0000000000..26309b5b93
--- /dev/null
+++ b/changelog.d/8537.misc
@@ -0,0 +1 @@
+Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`.
diff --git a/changelog.d/8542.misc b/changelog.d/8542.misc
new file mode 100644
index 0000000000..63149fd9b9
--- /dev/null
+++ b/changelog.d/8542.misc
@@ -0,0 +1 @@
+Improve database performance by executing more queries without starting transactions.
diff --git a/changelog.d/8547.misc b/changelog.d/8547.misc
new file mode 100644
index 0000000000..fafb1c8347
--- /dev/null
+++ b/changelog.d/8547.misc
@@ -0,0 +1 @@
+Enable mypy type checking for `synapse.util.caches`.
diff --git a/changelog.d/8548.misc b/changelog.d/8548.misc
new file mode 100644
index 0000000000..fba10bd731
--- /dev/null
+++ b/changelog.d/8548.misc
@@ -0,0 +1 @@
+Rename `Cache` to `DeferredCache`, to better reflect its purpose.
diff --git a/debian/changelog b/debian/changelog
index eeafd4f50a..8d873a4845 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,10 @@
+matrix-synapse-py3 (1.21.2) stable; urgency=medium
+
+ [ Synapse Packaging team ]
+ * New synapse release 1.21.2.
+
+ -- Synapse Packaging team <packages@matrix.org> Thu, 15 Oct 2020 09:23:27 -0400
+
matrix-synapse-py3 (1.21.1) stable; urgency=medium
[ Synapse Packaging team ]
diff --git a/docs/manhole.md b/docs/manhole.md
index 75b6ae40e0..37d1d7823c 100644
--- a/docs/manhole.md
+++ b/docs/manhole.md
@@ -5,22 +5,54 @@ The "manhole" allows server administrators to access a Python shell on a running
Synapse installation. This is a very powerful mechanism for administration and
debugging.
+**_Security Warning_**
+
+Note that this will give administrative access to synapse to **all users** with
+shell access to the server. It should therefore **not** be enabled in
+environments where untrusted users have shell access.
+
+***
+
To enable it, first uncomment the `manhole` listener configuration in
-`homeserver.yaml`:
+`homeserver.yaml`. The configuration is slightly different if you're using Docker.
+
+#### Docker config
+
+If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown:
```yaml
listeners:
- port: 9000
- bind_addresses: ['::1', '127.0.0.1']
+ bind_addresses: ['0.0.0.0']
type: manhole
```
-(`bind_addresses` in the above is important: it ensures that access to the
-manhole is only possible for local users).
+When using `docker run` to start the server, you will then need to change the command to the following to include the
+`manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it
+ensures that access to the `manhole` is only possible for local users.
-Note that this will give administrative access to synapse to **all users** with
-shell access to the server. It should therefore **not** be enabled in
-environments where untrusted users have shell access.
+```bash
+docker run -d --name synapse \
+ --mount type=volume,src=synapse-data,dst=/data \
+ -p 8008:8008 \
+ -p 127.0.0.1:9000:9000 \
+ matrixdotorg/synapse:latest
+```
+
+#### Native config
+
+If you are not using Docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown.
+The `bind_addresses` in the example below is important: it ensures that access to the
+`manhole` is only possible for local users.
+
+```yaml
+listeners:
+ - port: 9000
+ bind_addresses: ['::1', '127.0.0.1']
+ type: manhole
+```
+
+#### Accessing synapse manhole
Then restart synapse, and point an ssh client at port 9000 on localhost, using
the username `matrix`:
diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md
index 1dd60bdad9..75d2028e17 100644
--- a/docs/message_retention_policies.md
+++ b/docs/message_retention_policies.md
@@ -136,24 +136,34 @@ the server's database.
### Lifetime limits
-**Note: this feature is mainly useful within a closed federation or on
-servers that don't federate, because there currently is no way to
-enforce these limits in an open federation.**
-
-Server admins can restrict the values their local users are allowed to
-use for both `min_lifetime` and `max_lifetime`. These limits can be
-defined as such in the `retention` section of the configuration file:
+Server admins can set limits on the values of `max_lifetime` to use when
+purging old events in a room. These limits can be defined as such in the
+`retention` section of the configuration file:
```yaml
allowed_lifetime_min: 1d
allowed_lifetime_max: 1y
```
-Here, `allowed_lifetime_min` is the lowest value a local user can set
-for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max`
-is the highest value. Both parameters are optional (e.g. setting
-`allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a
-minimum and no maximum).
+The limits are considered when running purge jobs. If necessary, the
+effective value of `max_lifetime` will be brought between
+`allowed_lifetime_min` and `allowed_lifetime_max` (inclusive).
+This means that, if the value of `max_lifetime` defined in the room's state
+is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min`
+will be used instead. Likewise, if the value of `max_lifetime` is higher
+than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be
+used instead.
+
+In the example above, we ensure Synapse never deletes events that are less
+than one day old, and that it always deletes events that are over a year
+old.
+
+If a default policy is set, and its `max_lifetime` value is lower than
+`allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same
+process applies.
+
+Both parameters are optional; if one is omitted Synapse won't use it to
+adjust the effective value of `max_lifetime`.
Like other settings in this section, these parameters can be expressed
either as a duration or as a number of milliseconds.
diff --git a/mypy.ini b/mypy.ini
index f08fe992a4..b5db54ee3b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -15,6 +15,7 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
+ synapse/handlers/appservice.py,
synapse/handlers/account_data.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
@@ -64,9 +65,7 @@ files =
synapse/streams,
synapse/types.py,
synapse/util/async_helpers.py,
- synapse/util/caches/descriptors.py,
- synapse/util/caches/response_cache.py,
- synapse/util/caches/stream_change_cache.py,
+ synapse/util/caches,
synapse/util/metrics.py,
tests/replication,
tests/test_utils,
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 0647993658..f2b65a2105 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -1,4 +1,4 @@
-#!/bin/sh
+#!/bin/bash
#
# Runs linting scripts over the local Synapse checkout
# isort - sorts import statements
@@ -7,15 +7,90 @@
set -e
-if [ $# -ge 1 ]
-then
- files=$*
+usage() {
+ echo
+ echo "Usage: $0 [-h] [-d] [paths...]"
+ echo
+ echo "-d"
+ echo " Lint files that have changed since the last git commit."
+ echo
+ echo " If paths are provided and this option is set, both provided paths and those"
+ echo " that have changed since the last commit will be linted."
+ echo
+ echo " If no paths are provided and this option is not set, all files will be linted."
+ echo
+ echo " Note that paths with a file extension that is not '.py' will be excluded."
+ echo "-h"
+ echo " Display this help text."
+}
+
+USING_DIFF=0
+files=()
+
+while getopts ":dh" opt; do
+ case $opt in
+ d)
+ USING_DIFF=1
+ ;;
+ h)
+ usage
+ exit
+ ;;
+ \?)
+ echo "ERROR: Invalid option: -$OPTARG" >&2
+ usage
+ exit
+ ;;
+ esac
+done
+
+# Strip any options from the command line arguments now that
+# we've finished processing them
+shift "$((OPTIND-1))"
+
+if [ $USING_DIFF -eq 1 ]; then
+ # Check both staged and non-staged changes
+ for path in $(git diff HEAD --name-only); do
+ filename=$(basename "$path")
+ file_extension="${filename##*.}"
+
+ # If an extension is present, and it's something other than 'py',
+ # then ignore this file
+ if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then
+ continue
+ fi
+
+ # Append this path to our list of files to lint
+ files+=("$path")
+ done
+fi
+
+# Append any remaining arguments as files to lint
+files+=("$@")
+
+if [[ $USING_DIFF -eq 1 ]]; then
+ # If we were asked to lint changed files, and no paths were found as a result...
+ if [ ${#files[@]} -eq 0 ]; then
+ # Then print and exit
+ echo "No files found to lint."
+ exit 0
+ fi
else
- files="synapse tests scripts-dev scripts contrib synctl"
+ # If we were not asked to lint changed files, and no paths were found as a result,
+ # then lint everything!
+ if [[ -z ${files+x} ]]; then
+ # Lint all source code files and directories
+ files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
+ fi
fi
-echo "Linting these locations: $files"
-isort $files
-python3 -m black $files
+echo "Linting these paths: ${files[*]}"
+echo
+
+# Print out the commands being run
+set -x
+
+isort "${files[@]}"
+python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
-flake8 $files
+flake8 "${files[@]}"
diff --git a/setup.py b/setup.py
index 926b1bc86f..08843fe2a3 100755
--- a/setup.py
+++ b/setup.py
@@ -15,12 +15,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import glob
import os
-from setuptools import setup, find_packages, Command
-import sys
+from setuptools import Command, find_packages, setup
here = os.path.abspath(os.path.dirname(__file__))
diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi
index 073b806d3c..fa307483fe 100644
--- a/stubs/sortedcontainers/__init__.pyi
+++ b/stubs/sortedcontainers/__init__.pyi
@@ -1,13 +1,12 @@
-from .sorteddict import (
- SortedDict,
- SortedKeysView,
- SortedItemsView,
- SortedValuesView,
-)
+from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView
+from .sortedlist import SortedKeyList, SortedList, SortedListWithKey
__all__ = [
"SortedDict",
"SortedKeysView",
"SortedItemsView",
"SortedValuesView",
+ "SortedKeyList",
+ "SortedList",
+ "SortedListWithKey",
]
diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi
new file mode 100644
index 0000000000..8f6086b3ff
--- /dev/null
+++ b/stubs/sortedcontainers/sortedlist.pyi
@@ -0,0 +1,177 @@
+# stub for SortedList. This is an exact copy of
+# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi
+# (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
+
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ MutableSequence,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+
+_T = TypeVar("_T")
+_SL = TypeVar("_SL", bound=SortedList)
+_SKL = TypeVar("_SKL", bound=SortedKeyList)
+_Key = Callable[[_T], Any]
+_Repr = Callable[[], str]
+
+def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ...
+
+class SortedList(MutableSequence[_T]):
+
+ DEFAULT_LOAD_FACTOR: int = ...
+ def __init__(
+ self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
+ ): ...
+ # NB: currently mypy does not honour return type, see mypy #3307
+ @overload
+ def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ...
+ @overload
+ def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ...
+ @overload
+ def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ...
+ @overload
+ def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
+ @property
+ def key(self) -> Optional[Callable[[_T], Any]]: ...
+ def _reset(self, load: int) -> None: ...
+ def clear(self) -> None: ...
+ def _clear(self) -> None: ...
+ def add(self, value: _T) -> None: ...
+ def _expand(self, pos: int) -> None: ...
+ def update(self, iterable: Iterable[_T]) -> None: ...
+ def _update(self, iterable: Iterable[_T]) -> None: ...
+ def discard(self, value: _T) -> None: ...
+ def remove(self, value: _T) -> None: ...
+ def _delete(self, pos: int, idx: int) -> None: ...
+ def _loc(self, pos: int, idx: int) -> int: ...
+ def _pos(self, idx: int) -> int: ...
+ def _build_index(self) -> None: ...
+ def __contains__(self, value: Any) -> bool: ...
+ def __delitem__(self, index: Union[int, slice]) -> None: ...
+ @overload
+ def __getitem__(self, index: int) -> _T: ...
+ @overload
+ def __getitem__(self, index: slice) -> List[_T]: ...
+ @overload
+ def _getitem(self, index: int) -> _T: ...
+ @overload
+ def _getitem(self, index: slice) -> List[_T]: ...
+ @overload
+ def __setitem__(self, index: int, value: _T) -> None: ...
+ @overload
+ def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ...
+ def __iter__(self) -> Iterator[_T]: ...
+ def __reversed__(self) -> Iterator[_T]: ...
+ def __len__(self) -> int: ...
+ def reverse(self) -> None: ...
+ def islice(
+ self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
+ ) -> Iterator[_T]: ...
+ def _islice(
+ self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
+ ) -> Iterator[_T]: ...
+ def irange(
+ self,
+ minimum: Optional[int] = ...,
+ maximum: Optional[int] = ...,
+ inclusive: Tuple[bool, bool] = ...,
+ reverse: bool = ...,
+ ) -> Iterator[_T]: ...
+ def bisect_left(self, value: _T) -> int: ...
+ def bisect_right(self, value: _T) -> int: ...
+ def bisect(self, value: _T) -> int: ...
+ def _bisect_right(self, value: _T) -> int: ...
+ def count(self, value: _T) -> int: ...
+ def copy(self: _SL) -> _SL: ...
+ def __copy__(self: _SL) -> _SL: ...
+ def append(self, value: _T) -> None: ...
+ def extend(self, values: Iterable[_T]) -> None: ...
+ def insert(self, index: int, value: _T) -> None: ...
+ def pop(self, index: int = ...) -> _T: ...
+ def index(
+ self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
+ ) -> int: ...
+ def __add__(self: _SL, other: Iterable[_T]) -> _SL: ...
+ def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ...
+ def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ...
+ def __mul__(self: _SL, num: int) -> _SL: ...
+ def __rmul__(self: _SL, num: int) -> _SL: ...
+ def __imul__(self: _SL, num: int) -> _SL: ...
+ def __eq__(self, other: Any) -> bool: ...
+ def __ne__(self, other: Any) -> bool: ...
+ def __lt__(self, other: Sequence[_T]) -> bool: ...
+ def __gt__(self, other: Sequence[_T]) -> bool: ...
+ def __le__(self, other: Sequence[_T]) -> bool: ...
+ def __ge__(self, other: Sequence[_T]) -> bool: ...
+ def __repr__(self) -> str: ...
+ def _check(self) -> None: ...
+
+class SortedKeyList(SortedList[_T]):
+ def __init__(
+ self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
+ ) -> None: ...
+ def __new__(
+ cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
+ ) -> SortedKeyList[_T]: ...
+ @property
+ def key(self) -> Callable[[_T], Any]: ...
+ def clear(self) -> None: ...
+ def _clear(self) -> None: ...
+ def add(self, value: _T) -> None: ...
+ def _expand(self, pos: int) -> None: ...
+ def update(self, iterable: Iterable[_T]) -> None: ...
+ def _update(self, iterable: Iterable[_T]) -> None: ...
+ # NB: Must be T to be safely passed to self.func, yet base class imposes Any
+ def __contains__(self, value: _T) -> bool: ... # type: ignore
+ def discard(self, value: _T) -> None: ...
+ def remove(self, value: _T) -> None: ...
+ def _delete(self, pos: int, idx: int) -> None: ...
+ def irange(
+ self,
+ minimum: Optional[int] = ...,
+ maximum: Optional[int] = ...,
+ inclusive: Tuple[bool, bool] = ...,
+ reverse: bool = ...,
+ ): ...
+ def irange_key(
+ self,
+ min_key: Optional[Any] = ...,
+ max_key: Optional[Any] = ...,
+ inclusive: Tuple[bool, bool] = ...,
+ reserve: bool = ...,
+ ): ...
+ def bisect_left(self, value: _T) -> int: ...
+ def bisect_right(self, value: _T) -> int: ...
+ def bisect(self, value: _T) -> int: ...
+ def bisect_key_left(self, key: Any) -> int: ...
+ def _bisect_key_left(self, key: Any) -> int: ...
+ def bisect_key_right(self, key: Any) -> int: ...
+ def _bisect_key_right(self, key: Any) -> int: ...
+ def bisect_key(self, key: Any) -> int: ...
+ def count(self, value: _T) -> int: ...
+ def copy(self: _SKL) -> _SKL: ...
+ def __copy__(self: _SKL) -> _SKL: ...
+ def index(
+ self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
+ ) -> int: ...
+ def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
+ def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
+ def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
+ def __mul__(self: _SKL, num: int) -> _SKL: ...
+ def __rmul__(self: _SKL, num: int) -> _SKL: ...
+ def __imul__(self: _SKL, num: int) -> _SKL: ...
+ def __repr__(self) -> str: ...
+ def _check(self) -> None: ...
+
+SortedListWithKey = SortedKeyList
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 722b53a67d..83b8e4897f 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.21.1"
+__version__ = "1.21.2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index d53181deb1..1b511890aa 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -790,10 +790,6 @@ class FederationSenderHandler:
send_queue.process_rows_for_federation(self.federation_sender, rows)
await self.update_token(token)
- # We also need to poke the federation sender when new events happen
- elif stream_name == "events":
- self.federation_sender.notify_new_events(token)
-
# ... and when new receipts happen
elif stream_name == ReceiptsStream.NAME:
await self._on_new_receipts(rows)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 13ec1f71a6..3862d9c08f 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,14 +14,15 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes
-from synapse.appservice.api import ApplicationServiceApi
-from synapse.types import GroupID, get_domain_from_id
+from synapse.events import EventBase
+from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
+ from synapse.appservice.api import ApplicationServiceApi
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -32,38 +33,6 @@ class ApplicationServiceState:
UP = "up"
-class AppServiceTransaction:
- """Represents an application service transaction."""
-
- def __init__(self, service, id, events):
- self.service = service
- self.id = id
- self.events = events
-
- async def send(self, as_api: ApplicationServiceApi) -> bool:
- """Sends this transaction using the provided AS API interface.
-
- Args:
- as_api: The API to use to send.
- Returns:
- True if the transaction was sent.
- """
- return await as_api.push_bulk(
- service=self.service, events=self.events, txn_id=self.id
- )
-
- async def complete(self, store: "DataStore") -> None:
- """Completes this transaction as successful.
-
- Marks this transaction ID on the application service and removes the
- transaction contents from the database.
-
- Args:
- store: The database store to operate on.
- """
- await store.complete_appservice_txn(service=self.service, txn_id=self.id)
-
-
class ApplicationService:
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -91,6 +60,7 @@ class ApplicationService:
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
+ supports_ephemeral=False,
):
self.token = token
self.url = (
@@ -102,6 +72,7 @@ class ApplicationService:
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.ip_range_whitelist = ip_range_whitelist
+ self.supports_ephemeral = supports_ephemeral
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@@ -161,19 +132,21 @@ class ApplicationService:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
- def _matches_regex(self, test_string, namespace_key):
+ def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
for regex_obj in self.namespaces[namespace_key]:
if regex_obj["regex"].match(test_string):
return regex_obj
return None
- def _is_exclusive(self, ns_key, test_string):
+ def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj:
return regex_obj["exclusive"]
return False
- async def _matches_user(self, event, store):
+ async def _matches_user(
+ self, event: Optional[EventBase], store: Optional["DataStore"] = None
+ ) -> bool:
if not event:
return False
@@ -188,14 +161,23 @@ class ApplicationService:
if not store:
return False
- does_match = await self._matches_user_in_member_list(event.room_id, store)
+ does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
- @cached(num_args=1, cache_context=True)
- async def _matches_user_in_member_list(self, room_id, store, cache_context):
- member_list = await store.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
+ @cached(num_args=1)
+ async def matches_user_in_member_list(
+ self, room_id: str, store: "DataStore"
+ ) -> bool:
+ """Check if this service is interested in a room based upon its membership
+
+ Args:
+ room_id: The room to check.
+ store: The datastore to query.
+
+ Returns:
+ True if this service would like to know about this room.
+ """
+ member_list = await store.get_users_in_room(room_id)
# check joined member events
for user_id in member_list:
@@ -203,12 +185,14 @@ class ApplicationService:
return True
return False
- def _matches_room_id(self, event):
+ def _matches_room_id(self, event: EventBase) -> bool:
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
- async def _matches_aliases(self, event, store):
+ async def _matches_aliases(
+ self, event: EventBase, store: Optional["DataStore"] = None
+ ) -> bool:
if not store or not event:
return False
@@ -218,12 +202,15 @@ class ApplicationService:
return True
return False
- async def is_interested(self, event, store=None) -> bool:
+ async def is_interested(
+ self, event: EventBase, store: Optional["DataStore"] = None
+ ) -> bool:
"""Check if this service is interested in this event.
Args:
- event(Event): The event to check.
- store(DataStore)
+ event: The event to check.
+ store: The datastore to query.
+
Returns:
True if this service would like to know about this event.
"""
@@ -231,39 +218,66 @@ class ApplicationService:
if self._matches_room_id(event):
return True
+ # This will check the namespaces first before
+ # checking the store, so should be run before _matches_aliases
+ if await self._matches_user(event, store):
+ return True
+
+ # This will check the store, so should be run last
if await self._matches_aliases(event, store):
return True
- if await self._matches_user(event, store):
+ return False
+
+ @cached(num_args=1)
+ async def is_interested_in_presence(
+ self, user_id: UserID, store: "DataStore"
+ ) -> bool:
+ """Check if this service is interested in a user's presence
+
+ Args:
+ user_id: The user to check.
+ store: The datastore to query.
+
+ Returns:
+ True if this service would like to know about presence for this user.
+ """
+ # Find all the rooms the sender is in
+ if self.is_interested_in_user(user_id.to_string()):
return True
+ room_ids = await store.get_rooms_for_user(user_id.to_string())
+ # Then find out if the appservice is interested in any of those rooms
+ for room_id in room_ids:
+ if await self.matches_user_in_member_list(room_id, store):
+ return True
return False
- def is_interested_in_user(self, user_id):
+ def is_interested_in_user(self, user_id: str) -> bool:
return (
- self._matches_regex(user_id, ApplicationService.NS_USERS)
+ bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
or user_id == self.sender
)
- def is_interested_in_alias(self, alias):
+ def is_interested_in_alias(self, alias: str) -> bool:
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
- def is_interested_in_room(self, room_id):
+ def is_interested_in_room(self, room_id: str) -> bool:
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
- def is_exclusive_user(self, user_id):
+ def is_exclusive_user(self, user_id: str) -> bool:
return (
self._is_exclusive(ApplicationService.NS_USERS, user_id)
or user_id == self.sender
)
- def is_interested_in_protocol(self, protocol):
+ def is_interested_in_protocol(self, protocol: str) -> bool:
return protocol in self.protocols
- def is_exclusive_alias(self, alias):
+ def is_exclusive_alias(self, alias: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
- def is_exclusive_room(self, room_id):
+ def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exclusive_user_regexes(self):
@@ -276,14 +290,14 @@ class ApplicationService:
if regex_obj["exclusive"]
]
- def get_groups_for_user(self, user_id):
+ def get_groups_for_user(self, user_id: str) -> Iterable[str]:
"""Get the groups that this user is associated with by this AS
Args:
- user_id (str): The ID of the user.
+ user_id: The ID of the user.
Returns:
- iterable[str]: an iterable that yields group_id strings.
+ An iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
@@ -291,7 +305,7 @@ class ApplicationService:
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
)
- def is_rate_limited(self):
+ def is_rate_limited(self) -> bool:
return self.rate_limited
def __str__(self):
@@ -300,3 +314,45 @@ class ApplicationService:
dict_copy["token"] = "<redacted>"
dict_copy["hs_token"] = "<redacted>"
return "ApplicationService: %s" % (dict_copy,)
+
+
+class AppServiceTransaction:
+ """Represents an application service transaction."""
+
+ def __init__(
+ self,
+ service: ApplicationService,
+ id: int,
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ ):
+ self.service = service
+ self.id = id
+ self.events = events
+ self.ephemeral = ephemeral
+
+ async def send(self, as_api: "ApplicationServiceApi") -> bool:
+ """Sends this transaction using the provided AS API interface.
+
+ Args:
+ as_api: The API to use to send.
+ Returns:
+ True if the transaction was sent.
+ """
+ return await as_api.push_bulk(
+ service=self.service,
+ events=self.events,
+ ephemeral=self.ephemeral,
+ txn_id=self.id,
+ )
+
+ async def complete(self, store: "DataStore") -> None:
+ """Completes this transaction as successful.
+
+ Marks this transaction ID on the application service and removes the
+ transaction contents from the database.
+
+ Args:
+ store: The database store to operate on.
+ """
+ await store.complete_appservice_txn(service=self.service, txn_id=self.id)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e8f0793795..e366a982b8 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -14,12 +14,13 @@
# limitations under the License.
import logging
import urllib
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
+from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
@@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)
- async def push_bulk(self, service, events, txn_id=None):
+ async def push_bulk(
+ self,
+ service: "ApplicationService",
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ txn_id: Optional[int] = None,
+ ):
if service.url is None:
return True
@@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url
)
- txn_id = str(0)
- txn_id = str(txn_id)
+ txn_id = 0
+
+ uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
+
+ # Never send ephemeral events to appservices that do not support it
+ if service.supports_ephemeral:
+ body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
+ else:
+ body = {"events": events}
- uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
await self.put_json(
- uri=uri,
- json_body={"events": events},
- args={"access_token": service.hs_token},
+ uri=uri, json_body=body, args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 8eb8c6f51c..ad3c408519 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -49,10 +49,13 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
+from typing import List
-from synapse.appservice import ApplicationServiceState
+from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -82,8 +85,13 @@ class ApplicationServiceScheduler:
for service in services:
self.txn_ctrl.start_recoverer(service)
- def submit_event_for_as(self, service, event):
- self.queuer.enqueue(service, event)
+ def submit_event_for_as(self, service: ApplicationService, event: EventBase):
+ self.queuer.enqueue_event(service, event)
+
+ def submit_ephemeral_events_for_as(
+ self, service: ApplicationService, events: List[JsonDict]
+ ):
+ """Hand a batch of ephemeral events for `service` to the queuer.
+
+ Args:
+ service: The application service the events are destined for.
+ events: The ephemeral events to enqueue.
+ """
+ self.queuer.enqueue_ephemeral(service, events)
class _ServiceQueuer:
@@ -96,17 +104,15 @@ class _ServiceQueuer:
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
+ self.queued_ephemeral = {} # dict of {service_id: [events]}
# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
- def enqueue(self, service, event):
- self.queued_events.setdefault(service.id, []).append(event)
-
+ def _start_background_request(self, service):
# start a sender for this appservice if we don't already have one
-
if service.id in self.requests_in_flight:
return
@@ -114,7 +120,15 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service
)
- async def _send_request(self, service):
+ def enqueue_event(self, service: ApplicationService, event: EventBase):
+ """Queue a single event for the service and kick off sending.
+
+ Args:
+ service: The application service the event is destined for.
+ event: The event to append to the service's pending list.
+ """
+ self.queued_events.setdefault(service.id, []).append(event)
+ # Returns immediately if this service already has a request in flight.
+ self._start_background_request(service)
+
+ def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
+ """Queue ephemeral events for the service and kick off sending.
+
+ Args:
+ service: The application service the events are destined for.
+ events: The ephemeral events to add to the service's pending list.
+ """
+ self.queued_ephemeral.setdefault(service.id, []).extend(events)
+ # Returns immediately if this service already has a request in flight.
+ self._start_background_request(service)
+
+ async def _send_request(self, service: ApplicationService):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
assert service.id not in self.requests_in_flight
@@ -123,10 +137,11 @@ class _ServiceQueuer:
try:
while True:
events = self.queued_events.pop(service.id, [])
- if not events:
+ ephemeral = self.queued_ephemeral.pop(service.id, [])
+ if not events and not ephemeral:
return
try:
- await self.txn_ctrl.send(service, events)
+ await self.txn_ctrl.send(service, events, ephemeral)
except Exception:
logger.exception("AS request failed")
finally:
@@ -158,9 +173,16 @@ class _TransactionController:
# for UTs
self.RECOVERER_CLASS = _Recoverer
- async def send(self, service, events):
+ async def send(
+ self,
+ service: ApplicationService,
+ events: List[EventBase],
+ ephemeral: List[JsonDict] = [],
+ ):
try:
- txn = await self.store.create_appservice_txn(service=service, events=events)
+ txn = await self.store.create_appservice_txn(
+ service=service, events=events, ephemeral=ephemeral
+ )
service_is_up = await self._is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
@@ -204,7 +226,7 @@ class _TransactionController:
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
- async def _is_service_up(self, service):
+ async def _is_service_up(self, service: ApplicationService) -> bool:
state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 8ed3e24258..746fc3cc02 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename):
if as_info.get("ip_range_whitelist"):
ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
+ supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
+
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
@@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
+ supports_ephemeral=supports_ephemeral,
protocols=protocols,
rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist,
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index b6c47be646..df4f950fec 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self):
return self._state_key is not None
- async def build(self, prev_event_ids: List[str]) -> EventBase:
+ async def build(
+ self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
+ ) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
prev_event_ids: The event IDs to use as the prev events
+ auth_event_ids: The event IDs to use as the auth events.
+ Should normally be set to None, which will cause them to be calculated
+ based on the room state at the prev_events.
Returns:
The signed and hashed event.
"""
-
- state_ids = await self._state.get_current_state_ids(
- self.room_id, prev_event_ids
- )
- auth_ids = self._auth.compute_auth_events(self, state_ids)
+ if auth_event_ids is None:
+ state_ids = await self._state.get_current_state_ids(
+ self.room_id, prev_event_ids
+ )
+ auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes(
- auth_ids
+ auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes(
prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
- auth_events = auth_ids
+ auth_events = auth_event_ids
prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 9df35b54ba..5f9af8529b 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -83,6 +83,9 @@ class EventValidator:
Args:
event (FrozenEvent): The event to validate.
"""
+ if not event.is_state():
+ raise SynapseError(code=400, msg="must be a state event")
+
min_lifetime = event.content.get("min_lifetime")
max_lifetime = event.content.get("max_lifetime")
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 8e46957d15..5f1bf492c1 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -188,7 +188,7 @@ class FederationRemoteSendQueue:
for key in keys[:i]:
del self.edus[key]
- def notify_new_events(self, current_id):
+ def notify_new_events(self, max_token):
"""As per FederationSender"""
# We don't need to replicate this as it gets sent down a different
# stream.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index e33b29a42c..604cfd1935 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -40,7 +40,7 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import ReadReceipt
+from synapse.types import ReadReceipt, RoomStreamToken
from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@@ -154,10 +154,15 @@ class FederationSender:
self._per_destination_queues[destination] = queue
return queue
- def notify_new_events(self, current_id: int) -> None:
+ def notify_new_events(self, max_token: RoomStreamToken) -> None:
"""This gets called when we have some new events we might want to
send out to other servers.
"""
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ current_id = max_token.stream
+
self._last_poked_id = max(current_id, self._last_poked_id)
if self._is_processing:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9d4e87dad6..07240d3a14 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict, List, Optional
from prometheus_client import Counter
@@ -21,12 +22,16 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
+from synapse.appservice import ApplicationService
+from synapse.events import EventBase
+from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
event_processing_loop_counter,
event_processing_loop_room_count,
)
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -43,19 +48,22 @@ class ApplicationServicesHandler:
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
+ self.event_sources = hs.get_event_sources()
self.current_max = 0
self.is_processing = False
- async def notify_interested_services(self, current_id):
+ async def notify_interested_services(self, max_token: RoomStreamToken):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
-
- Args:
- current_id(int): The current maximum ID.
"""
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ current_id = max_token.stream
+
services = self.store.get_app_services()
if not services or not self.notify_appservices:
return
@@ -79,7 +87,7 @@ class ApplicationServicesHandler:
if not events:
break
- events_by_room = {}
+ events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@@ -158,6 +166,104 @@ class ApplicationServicesHandler:
finally:
self.is_processing = False
+    async def notify_interested_services_ephemeral(
+        self,
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[UserID] = (),
+    ):
+        """Push new ephemeral events to the application services that want them.
+
+        This is called by the notifier in the background when an ephemeral
+        event is handled by the homeserver. It determines which appservices are
+        interested in the event and submits it to the scheduler for each one.
+
+        Events will only be pushed to appservices that have opted into
+        ephemeral events.
+
+        Args:
+            stream_key: The stream the event came from. Only "typing_key",
+                "receipt_key" and "presence_key" are acted upon; any other
+                value is ignored.
+            new_token: The latest stream token. Typing is only handled when
+                this is not None.
+            users: The user(s) involved with the event. Only consulted for
+                presence.
+        """
+        services = [
+            service
+            for service in self.store.get_app_services()
+            if service.supports_ephemeral
+        ]
+        if not services or not self.notify_appservices:
+            return
+        # Pass the argument to the logger instead of pre-formatting with `%`, so
+        # the message is only built if the record is actually emitted.
+        logger.info("Checking interested services for %s", stream_key)
+        with Measure(self.clock, "notify_interested_services_ephemeral"):
+            for service in services:
+                # Only handle typing if we have the latest token
+                if stream_key == "typing_key" and new_token is not None:
+                    events = await self._handle_typing(service, new_token)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    # We don't persist the token for typing_key for performance reasons
+                elif stream_key == "receipt_key":
+                    events = await self._handle_receipts(service)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "read_receipt", new_token
+                    )
+                elif stream_key == "presence_key":
+                    events = await self._handle_presence(service, users)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "presence", new_token
+                    )
+
+    async def _handle_typing(self, service: ApplicationService, new_token: int):
+        """Fetch the latest typing events for the given application service.
+
+        Args:
+            service: The application service to fetch typing events for.
+            new_token: The latest typing stream token.
+        Returns:
+            The events returned by the typing event source.
+        """
+        source = self.event_sources.sources["typing"]
+        # For performance reasons the previous token is not persisted in the DB.
+        # Instead we always fetch the latest typing information, i.e. everything
+        # from just before the current position.
+        previous_token = new_token - 1
+        events, _ = await source.get_new_events_as(
+            from_key=previous_token, service=service,
+        )
+        return events
+
+    async def _handle_receipts(self, service: ApplicationService):
+        """Fetch the receipts added since the position stored for `service`.
+
+        Args:
+            service: The application service to fetch receipts for.
+        Returns:
+            The events returned by the receipt event source.
+        """
+        # Resume from the "read_receipt" stream position recorded for this service.
+        last_position = await self.store.get_type_stream_id_for_appservice(
+            service, "read_receipt"
+        )
+        source = self.event_sources.sources["receipt"]
+        events, _ = await source.get_new_events_as(
+            from_key=last_position, service=service
+        )
+        return events
+
+    async def _handle_presence(
+        self, service: ApplicationService, users: Collection[UserID]
+    ) -> List[JsonDict]:
+        """Build the `m.presence` events for `users` that `service` should receive.
+
+        Args:
+            service: The application service to build events for.
+            users: The users involved with the presence update.
+        Returns:
+            The formatted presence events for those users the service is
+            interested in; an empty list if there are none.
+        """
+        events = []  # type: List[JsonDict]
+        presence_source = self.event_sources.sources["presence"]
+        # Resume from the "presence" stream position recorded for this service.
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "presence"
+        )
+        for user in users:
+            interested = await service.is_interested_in_presence(user, self.store)
+            if not interested:
+                continue
+            presence_events, _ = await presence_source.get_new_events(
+                user=user, service=service, from_key=from_key,
+            )
+            time_now = self.clock.time_msec()
+            events.extend(
+                {
+                    "type": "m.presence",
+                    "sender": event.user_id,
+                    "content": format_user_presence_state(
+                        event, time_now, include_user_id=False
+                    ),
+                }
+                for event in presence_events
+            )
+
+        # The caller only submits events when this result is truthy, so the
+        # accumulated list has to be returned explicitly.
+        return events
+
async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
@@ -220,7 +326,7 @@ class ApplicationServicesHandler:
async def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services()
- protocols = {}
+ protocols = {} # type: Dict[str, List[JsonDict]]
# Collect up all the individual protocol responses out of the ASes
for s in services:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0c6aec347e..7f00805a91 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -437,9 +437,9 @@ class EventCreationHandler:
self,
requester: Requester,
event_dict: dict,
- token_id: Optional[str] = None,
txn_id: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
+ auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@@ -453,13 +453,18 @@ class EventCreationHandler:
Args:
requester
event_dict: An entire event
- token_id
txn_id
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
If None, they will be requested from the database.
+
+ auth_event_ids:
+ The event ids to use as the auth_events for the new event.
+ Should normally be left as None, which will cause them to be calculated
+ based on the room state at the prev_events.
+
require_consent: Whether to check if the requester has
consented to the privacy policy.
Raises:
@@ -511,14 +516,17 @@ class EventCreationHandler:
if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester)
- if token_id is not None:
- builder.internal_metadata.token_id = token_id
+ if requester.access_token_id is not None:
+ builder.internal_metadata.token_id = requester.access_token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = await self.create_new_client_event(
- builder=builder, requester=requester, prev_event_ids=prev_event_ids,
+ builder=builder,
+ requester=requester,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@@ -726,7 +734,7 @@ class EventCreationHandler:
return event, event.internal_metadata.stream_ordering
event, context = await self.create_event(
- requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
+ requester, event_dict, txn_id=txn_id
)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
@@ -757,6 +765,7 @@ class EventCreationHandler:
builder: EventBuilder,
requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None,
+ auth_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -769,6 +778,11 @@ class EventCreationHandler:
If None, they will be requested from the database.
+ auth_event_ids:
+ The event ids to use as the auth_events for the new event.
+ Should normally be left as None, which will cause them to be calculated
+ based on the room state at the prev_events.
+
Returns:
Tuple of created event, context
"""
@@ -790,7 +804,9 @@ class EventCreationHandler:
builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events"
- event = await builder.build(prev_event_ids=prev_event_ids)
+ event = await builder.build(
+ prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
+ )
context = await self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 7225923757..c242c409cf 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
+from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
-from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -140,5 +142,36 @@ class ReceiptEventSource:
return (events, to_key)
+    async def get_new_events_as(
+        self, from_key: int, service: ApplicationService
+    ) -> Tuple[List[JsonDict], int]:
+        """Collect the new receipt events that an appservice may be interested in.
+
+        Args:
+            from_key: the stream position at which events should be fetched from
+            service: The appservice which may be interested
+        Returns:
+            A tuple of the matching receipt events and the current receipt
+            stream key.
+        """
+        start = int(from_key)
+        end = self.get_current_key()
+
+        # No stream positions to cover between the two keys.
+        if start == end:
+            return [], end
+
+        # Fetch every new receipt first ...
+        receipts_by_room = await self.store.get_linearized_receipts_for_all_rooms(
+            from_key=start, to_key=end
+        )
+
+        # ... then keep only those in rooms that the AS can read.
+        visible = []
+        for room_id, receipt in receipts_by_room.items():
+            if await service.matches_user_in_member_list(room_id, self.store):
+                visible.append(receipt)
+
+        return (visible, end)
+
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 93ed51063a..ec300d8877 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler):
"replacement_room": new_room_id,
},
},
- token_id=requester.access_token_id,
)
old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0080eeaf8d..ec784030e9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,12 +17,10 @@ import abc
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
-
-from unpaddedbase64 import encode_base64
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
-from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -31,12 +29,8 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
-from synapse.api.room_versions import EventFormatVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
-from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
-from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
@@ -193,7 +187,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# For backwards compatibility:
"membership": membership,
},
- token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
require_consent=require_consent,
@@ -1133,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id = invite_event.room_id
target_user = invite_event.state_key
- room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE
- # the auth events for the new event are the same as that of the invite, plus
- # the invite itself.
- #
- # the prev_events are just the invite.
- invite_hash = invite_event.event_id # type: Union[str, Tuple]
- if room_version.event_format == EventFormatVersions.V1:
- alg, h = compute_event_reference_hash(invite_event)
- invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
-
- auth_events = tuple(invite_event.auth_events) + (invite_hash,)
- prev_events = (invite_hash,)
-
- # we cap depth of generated events, to ensure that they are not
- # rejected by other servers (and so that they can be persisted in
- # the db)
- depth = min(invite_event.depth + 1, MAX_DEPTH)
-
event_dict = {
- "depth": depth,
- "auth_events": auth_events,
- "prev_events": prev_events,
"type": EventTypes.Member,
"room_id": room_id,
"sender": target_user,
@@ -1165,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
- event = create_local_event_from_event_dict(
- clock=self.clock,
- hostname=self.hs.hostname,
- signing_key=self.hs.signing_key,
- room_version=room_version,
- event_dict=event_dict,
+ # the auth events for the new event are the same as that of the invite, plus
+ # the invite itself.
+ #
+ # the prev_events are just the invite.
+ prev_event_ids = [invite_event.event_id]
+ auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+
+ event, context = await self.event_creation_handler.create_event(
+ requester,
+ event_dict,
+ txn_id=txn_id,
+ prev_event_ids=prev_event_ids,
+ auth_event_ids=auth_event_ids,
)
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
- if txn_id is not None:
- event.internal_metadata.txn_id = txn_id
- if requester.access_token_id is not None:
- event.internal_metadata.token_id = requester.access_token_id
-
- EventValidator().validate_new(event, self.config)
- context = await self.state_handler.compute_event_context(event)
- context.app_service = requester.app_service
result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a306631094..b527724bc4 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3cbfc2d780..d3692842e3 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -430,6 +430,33 @@ class TypingNotificationEventSource:
"content": {"user_ids": list(typing)},
}
+    async def get_new_events_as(
+        self, from_key: int, service: ApplicationService
+    ) -> Tuple[List[JsonDict], int]:
+        """Returns a set of new typing events that an appservice
+        may be interested in.
+
+        Args:
+            from_key: the stream position at which events should be fetched from
+            service: The appservice which may be interested
+        Returns:
+            A tuple of the typing events for the matching rooms and the typing
+            handler's latest room serial.
+        """
+        with Measure(self.clock, "typing.get_new_events_as"):
+            from_key = int(from_key)
+            handler = self.get_typing_handler()
+
+            events = []
+            # Iterate over a snapshot: the `await` below suspends this coroutine,
+            # and if a room is added to or removed from `_room_serials` while we
+            # are suspended, iterating the live dict would raise
+            # "RuntimeError: dictionary changed size during iteration".
+            for room_id, serial in list(handler._room_serials.items()):
+                # Skip rooms with nothing newer than the requested position.
+                if serial <= from_key:
+                    continue
+                if not await service.matches_user_in_member_list(
+                    room_id, handler.store
+                ):
+                    continue
+
+                events.append(self._make_event_for(room_id))
+
+            return (events, handler._latest_room_serial)
+
async def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 13adeed01e..2e993411b9 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -319,19 +319,35 @@ class Notifier:
)
if self.federation_sender:
- self.federation_sender.notify_new_events(max_room_stream_token.stream)
+ self.federation_sender.notify_new_events(max_room_stream_token)
async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
try:
await self.appservice_handler.notify_interested_services(
- max_room_stream_token.stream
+ max_room_stream_token
+ )
+ except Exception:
+ logger.exception("Error notifying application services of event")
+
+ async def _notify_app_services_ephemeral(
+ self,
+ stream_key: str,
+ new_token: Union[int, RoomStreamToken],
+ users: Collection[UserID] = [],
+ ):
+ try:
+ stream_token = None
+ if isinstance(new_token, int):
+ stream_token = new_token
+ await self.appservice_handler.notify_interested_services_ephemeral(
+ stream_key, stream_token, users
)
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
- await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
+ await self._pusher_pool.on_new_notifications(max_room_stream_token)
except Exception:
logger.exception("Error pusher pool of event")
@@ -367,6 +383,15 @@ class Notifier:
self.notify_replication()
+ # Notify appservices
+ run_as_background_process(
+ "_notify_app_services_ephemeral",
+ self._notify_app_services_ephemeral,
+ stream_key,
+ new_token,
+ users,
+ )
+
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 28bd8ab748..c6763971ee 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
@@ -91,7 +92,12 @@ class EmailPusher:
pass
self.timed_call = None
- def on_new_notifications(self, max_stream_ordering):
+ def on_new_notifications(self, max_token: RoomStreamToken):
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_ordering = max_token.stream
+
if self.max_stream_ordering:
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 26706bf3e1..793d0db2d9 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
+from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools
@@ -114,7 +115,12 @@ class HttpPusher:
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, max_stream_ordering):
+ def on_new_notifications(self, max_token: RoomStreamToken):
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_ordering = max_token.stream
+
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering or 0
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 76150e117b..0080c68ce2 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -24,6 +24,7 @@ from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory
+from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
@@ -186,11 +187,16 @@ class PusherPool:
)
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
- async def on_new_notifications(self, max_stream_id: int):
+ async def on_new_notifications(self, max_token: RoomStreamToken):
if not self.pushers:
# nothing to do here.
return
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_id = max_token.stream
+
if max_stream_id < self._last_room_stream_id_seen:
# Nothing to do
return
@@ -214,7 +220,7 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
- p.on_new_notifications(max_stream_id)
+ p.on_new_notifications(max_token)
except Exception:
logger.exception("Exception in pusher on_new_notifications")
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1f8dafe7ea..4b0ea0cc01 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.deferred_cache import DeferredCache
from ._base import BaseSlavedStore
@@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self.client_ip_last_seen = Cache(
+ self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
- )
+ ) # type: DeferredCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index b686cd671f..e7fcd2b1ff 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -59,7 +59,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
try:
new_name = content["displayname"]
except Exception:
- return 400, "Unable to parse name"
+ raise SynapseError(
+ code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
+ )
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
diff --git a/synapse/server.py b/synapse/server.py
index f921ee4b53..21a232bbd9 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -205,7 +205,13 @@ class HomeServer(metaclass=abc.ABCMeta):
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
- def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
+ def __init__(
+ self,
+ hostname: str,
+ config: HomeServerConfig,
+ reactor=None,
+ version_string="Synapse",
+ ):
"""
Args:
hostname : The hostname for the server.
@@ -236,11 +242,9 @@ class HomeServer(metaclass=abc.ABCMeta):
burst_count=config.rc_registration.burst_count,
)
- self.datastores = None # type: Optional[Databases]
+ self.version_string = version_string
- # Other kwargs are explicit dependencies
- for depname in kwargs:
- setattr(self, depname, kwargs[depname])
+ self.datastores = None # type: Optional[Databases]
def get_instance_id(self) -> str:
"""A unique ID for this synapse process instance.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0ba3a025cf..763722d6bc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -893,6 +893,12 @@ class DatabasePool:
attempts = 0
while True:
try:
+ # We can autocommit if we are going to use native upserts
+ autocommit = (
+ self.engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ )
+
return await self.runInteraction(
desc,
self.simple_upsert_txn,
@@ -901,6 +907,7 @@ class DatabasePool:
values,
insertion_values,
lock=lock,
+ db_autocommit=autocommit,
)
except self.engine.module.IntegrityError as e:
attempts += 1
@@ -1063,6 +1070,43 @@ class DatabasePool:
)
txn.execute(sql, list(allvalues.values()))
+ async def simple_upsert_many(
+ self,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ desc: str,
+ ) -> None:
+ """
+ Upsert, many times.
+
+ Args:
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+            value_values: Each row's value columns. Ignored if value_names is empty.
+            desc: description of the transaction, for logging and metrics
+ """
+
+ # We can autocommit if we are going to use native upserts
+ autocommit = (
+ self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
+ )
+
+ return await self.runInteraction(
+ desc,
+ self.simple_upsert_many_txn,
+ table,
+ key_names,
+ key_values,
+ value_names,
+ value_values,
+ db_autocommit=autocommit,
+ )
+
def simple_upsert_many_txn(
self,
txn: LoggingTransaction,
@@ -1214,7 +1258,13 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
return await self.runInteraction(
- desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+ desc,
+ self.simple_select_one_txn,
+ table,
+ keyvalues,
+ retcols,
+ allow_none,
+ db_autocommit=True,
)
@overload
@@ -1265,6 +1315,7 @@ class DatabasePool:
keyvalues,
retcol,
allow_none=allow_none,
+ db_autocommit=True,
)
@overload
@@ -1346,7 +1397,12 @@ class DatabasePool:
Results in a list
"""
return await self.runInteraction(
- desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+ desc,
+ self.simple_select_onecol_txn,
+ table,
+ keyvalues,
+ retcol,
+ db_autocommit=True,
)
async def simple_select_list(
@@ -1371,7 +1427,12 @@ class DatabasePool:
A list of dictionaries.
"""
return await self.runInteraction(
- desc, self.simple_select_list_txn, table, keyvalues, retcols
+ desc,
+ self.simple_select_list_txn,
+ table,
+ keyvalues,
+ retcols,
+ db_autocommit=True,
)
@classmethod
@@ -1450,6 +1511,7 @@ class DatabasePool:
chunk,
keyvalues,
retcols,
+ db_autocommit=True,
)
results.extend(rows)
@@ -1548,7 +1610,12 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
- desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+ desc,
+ self.simple_update_one_txn,
+ table,
+ keyvalues,
+ updatevalues,
+ db_autocommit=True,
)
@classmethod
@@ -1607,7 +1674,9 @@ class DatabasePool:
keyvalues: dict of column names and values to select the row with
desc: description of the transaction, for logging and metrics
"""
- await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+ await self.runInteraction(
+ desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+ )
@staticmethod
def simple_delete_one_txn(
@@ -1646,7 +1715,9 @@ class DatabasePool:
Returns:
The number of deleted rows.
"""
- return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+ return await self.runInteraction(
+ desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
+ )
@staticmethod
def simple_delete_txn(
@@ -1694,7 +1765,13 @@ class DatabasePool:
Number rows deleted
"""
return await self.runInteraction(
- desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+ desc,
+ self.simple_delete_many_txn,
+ table,
+ column,
+ iterable,
+ keyvalues,
+ db_autocommit=True,
)
@staticmethod
@@ -1860,7 +1937,13 @@ class DatabasePool:
"""
return await self.runInteraction(
- desc, self.simple_search_list_txn, table, term, col, retcols
+ desc,
+ self.simple_search_list_txn,
+ table,
+ term,
+ col,
+ retcols,
+ db_autocommit=True,
)
@classmethod
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 85f6b1e3fd..43bf0f649a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,12 +15,15 @@
# limitations under the License.
import logging
import re
+from typing import List
-from synapse.appservice import AppServiceTransaction
+from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config.appservice import load_appservices
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.types import JsonDict
from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -172,15 +175,23 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state}
)
- async def create_appservice_txn(self, service, events):
+ async def create_appservice_txn(
+ self,
+ service: ApplicationService,
+ events: List[EventBase],
+ ephemeral: List[JsonDict],
+ ) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
- with the given list of events.
+ with the given list of events. Ephemeral events are NOT persisted to the
+ database and are not resent if a transaction is retried.
Args:
- service(ApplicationService): The service who the transaction is for.
- events(list<Event>): A list of events to put in the transaction.
+ service: The service who the transaction is for.
+ events: A list of persistent events to put in the transaction.
+ ephemeral: A list of ephemeral events to put in the transaction.
+
Returns:
- AppServiceTransaction: A new transaction.
+ A new transaction.
"""
def _create_appservice_txn(txn):
@@ -207,7 +218,9 @@ class ApplicationServiceTransactionWorkerStore(
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids),
)
- return AppServiceTransaction(service=service, id=new_txn_id, events=events)
+ return AppServiceTransaction(
+ service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+ )
return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
@@ -296,7 +309,9 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
+ return AppServiceTransaction(
+ service=service, id=entry["txn_id"], events=events, ephemeral=[]
+ )
def _get_last_txn(self, txn, service_id):
txn.execute(
@@ -320,7 +335,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def get_new_events_for_appservice(self, current_id, limit):
- """Get all new evnets"""
+ """Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
sql = (
@@ -351,6 +366,39 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events
+    async def get_type_stream_id_for_appservice(
+        self, service: ApplicationService, type: str
+    ) -> int:
+        """Get the last `<type>_stream_id` stored for `service`, or 0 if unset."""
+        if type not in ("read_receipt", "presence"):
+            raise ValueError("Unknown appservice stream id type: %r" % (type,))
+
+        def get_type_stream_id_for_appservice_txn(txn):
+            # Column names cannot be bound as `?` parameters (that selects the
+            # literal string), so interpolate the validated column name instead.
+            sql = "SELECT %s_stream_id FROM application_services_state WHERE as_id=?"
+            txn.execute(sql % (type,), (service.id,))
+            row = txn.fetchone()
+            return 0 if row is None or row[0] is None else int(row[0])
+
+        return await self.db_pool.runInteraction(
+            "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
+        )
+
+    async def set_type_stream_id_for_appservice(
+        self, service: ApplicationService, type: str, pos: int
+    ) -> None:
+        """Store `pos` in the `<type>_stream_id` column of `service`'s state row."""
+        if type not in ("read_receipt", "presence"):
+            raise ValueError("Unknown appservice stream id type: %r" % (type,))
+
+        # Identifiers can't be bound as `?`; interpolate the validated column name.
+        sql = "UPDATE application_services_state SET %s_stream_id = ? WHERE as_id=?"
+        await self.db_pool.runInteraction(
+            "set_type_stream_id_for_appservice",
+            lambda txn: txn.execute(sql % (type,), (pos, service.id)),
+        )
+
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index a25a888443..9e66e6648a 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- self.client_ip_last_seen = Cache(
+ self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 88fd97e1df..e662a20d24 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -34,7 +34,8 @@ from synapse.storage.database import (
)
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
-from synapse.util.caches.descriptors import Cache, cached, cachedList
+from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -1004,7 +1005,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
- self.device_id_exists_cache = Cache(
+ self.device_id_exists_cache = DeferredCache(
name="device_id_exists", keylen=2, max_entries=10000
)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index fdb17745f6..ba3b1769b0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1270,6 +1270,10 @@ class PersistEventsStore:
)
def _store_retention_policy_for_room_txn(self, txn, event):
+ if not event.is_state():
+ logger.debug("Ignoring non-state m.room.retention event")
+ return
+
if hasattr(event, "content") and (
"min_lifetime" in event.content or "max_lifetime" in event.content
):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 3ec4d1d9c2..ff150f0be7 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached
+from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -145,7 +146,7 @@ class EventsWorkerStore(SQLBaseStore):
self._cleanup_old_transaction_ids,
)
- self._get_event_cache = Cache(
+ self._get_event_cache = DeferredCache(
"*getEvent*",
keylen=3,
max_entries=hs.config.caches.event_cache_size,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index ad43bb05ab..f8f4bb9b3f 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
- await self.db_pool.runInteraction(
- "store_server_verify_keys",
- self.db_pool.simple_upsert_many_txn,
+ await self.db_pool.simple_upsert_many(
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore):
"verify_key",
),
value_values=value_values,
+ desc="store_server_verify_keys",
)
invalidate = self._get_server_verify_key.invalidate
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 0acf0617ca..79b01d16f9 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -281,9 +281,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
a_day_in_milliseconds = 24 * 60 * 60 * 1000
now = self._clock.time_msec()
+            # A note on user_agent. Technically a given device can have multiple
+            # user agents, so we need to decide which one to pick. We could have handled
+            # this in a number of ways, but given that we don't care _that_ much we
+            # have gone for MAX(). For more details of the other options considered see
+            # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111
sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent)
+ SELECT u.user_id, u.device_id, ?, MAX(u.user_agent)
FROM user_ips AS u
LEFT JOIN (
SELECT user_id, device_id, timestamp FROM user_daily_visits
@@ -294,7 +299,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
WHERE last_seen > ? AND last_seen <= ?
AND udv.timestamp IS NULL AND users.is_guest=0
AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
+                    GROUP BY u.user_id, u.device_id
"""
# This means that the day has rolled over but there could still
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c79ddff680..5cdf16521c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
@@ -274,6 +275,60 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
+ @cached(num_args=2,)
+ async def get_linearized_receipts_for_all_rooms(
+ self, to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, JsonDict]:
+ """Get receipts for all rooms between two stream_ids.
+
+ Args:
+            to_key: Max stream id to fetch receipts up to.
+ from_key: Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+            A dictionary of room IDs to one batched m.receipt event per room.
+ """
+
+ def f(txn):
+ if from_key:
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id > ? AND stream_id <= ?
+ """
+ txn.execute(sql, [from_key, to_key])
+ else:
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id <= ?
+ """
+
+ txn.execute(sql, [to_key])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ txn_results = await self.db_pool.runInteraction(
+ "get_linearized_receipts_for_all_rooms", f
+ )
+
+ results = {}
+ for row in txn_results:
+ # We want a single event per room, since we want to batch the
+ # receipts by room, event and type.
+ room_event = results.setdefault(
+ row["room_id"],
+ {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ )
+
+ # The content is of the form:
+ # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
+ event_entry = room_event["content"].setdefault(row["event_id"], {})
+ receipt_type = event_entry.setdefault(row["receipt_type"], {})
+
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
+
+ return results
+
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:
diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
new file mode 100644
index 0000000000..b0b5dcddce
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Add new column to user_daily_visits to track user agent
+ALTER TABLE user_daily_visits
+ ADD COLUMN user_agent TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql b/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql
new file mode 100644
index 0000000000..20f5a95a24
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- SQLite accepts only one ADD COLUMN per ALTER TABLE, so use separate statements.
+ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
+ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 7d46090267..59207cadd4 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -208,42 +208,56 @@ class TransactionStore(TransactionWorkerStore):
"""
self._destination_retry_cache.pop(destination, None)
- return await self.db_pool.runInteraction(
- "set_destination_retry_timings",
- self._set_destination_retry_timings,
- destination,
- failure_ts,
- retry_last_ts,
- retry_interval,
- )
+ if self.database_engine.can_native_upsert:
+ return await self.db_pool.runInteraction(
+ "set_destination_retry_timings",
+ self._set_destination_retry_timings_native,
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ db_autocommit=True, # Safe as its a single upsert
+ )
+ else:
+ return await self.db_pool.runInteraction(
+ "set_destination_retry_timings",
+ self._set_destination_retry_timings_emulated,
+ destination,
+ failure_ts,
+ retry_last_ts,
+ retry_interval,
+ )
- def _set_destination_retry_timings(
+ def _set_destination_retry_timings_native(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
+ assert self.database_engine.can_native_upsert
+
+ # Upsert retry time interval if retry_interval is zero (i.e. we're
+ # resetting it) or greater than the existing retry interval.
+ #
+ # WARNING: This is executed in autocommit, so we shouldn't add any more
+ # SQL calls in here (without being very careful).
+ sql = """
+ INSERT INTO destinations (
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET
+ failure_ts = EXCLUDED.failure_ts,
+ retry_last_ts = EXCLUDED.retry_last_ts,
+ retry_interval = EXCLUDED.retry_interval
+ WHERE
+ EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval IS NULL
+ OR destinations.retry_interval < EXCLUDED.retry_interval
+ """
- if self.database_engine.can_native_upsert:
- # Upsert retry time interval if retry_interval is zero (i.e. we're
- # resetting it) or greater than the existing retry interval.
-
- sql = """
- INSERT INTO destinations (
- destination, failure_ts, retry_last_ts, retry_interval
- )
- VALUES (?, ?, ?, ?)
- ON CONFLICT (destination) DO UPDATE SET
- failure_ts = EXCLUDED.failure_ts,
- retry_last_ts = EXCLUDED.retry_last_ts,
- retry_interval = EXCLUDED.retry_interval
- WHERE
- EXCLUDED.retry_interval = 0
- OR destinations.retry_interval IS NULL
- OR destinations.retry_interval < EXCLUDED.retry_interval
- """
-
- txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
-
- return
+ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+ def _set_destination_retry_timings_emulated(
+ self, txn, destination, failure_ts, retry_last_ts, retry_interval
+ ):
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 5a390ff2f6..d87ceec6da 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id_tuples: iterable of 2-tuple of user IDs.
"""
- def _add_users_who_share_room_txn(txn):
- self.db_pool.simple_upsert_many_txn(
- txn,
- table="users_who_share_private_rooms",
- key_names=["user_id", "other_user_id", "room_id"],
- key_values=[
- (user_id, other_user_id, room_id)
- for user_id, other_user_id in user_id_tuples
- ],
- value_names=(),
- value_values=None,
- )
-
- await self.db_pool.runInteraction(
- "add_users_who_share_room", _add_users_who_share_room_txn
+ await self.db_pool.simple_upsert_many(
+ table="users_who_share_private_rooms",
+ key_names=["user_id", "other_user_id", "room_id"],
+ key_values=[
+ (user_id, other_user_id, room_id)
+ for user_id, other_user_id in user_id_tuples
+ ],
+ value_names=(),
+ value_values=None,
+ desc="add_users_who_share_room",
)
async def add_users_in_public_rooms(
@@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_ids
"""
- def _add_users_in_public_rooms_txn(txn):
-
- self.db_pool.simple_upsert_many_txn(
- txn,
- table="users_in_public_rooms",
- key_names=["user_id", "room_id"],
- key_values=[(user_id, room_id) for user_id in user_ids],
- value_names=(),
- value_values=None,
- )
-
- await self.db_pool.runInteraction(
- "add_users_in_public_rooms", _add_users_in_public_rooms_txn
+ await self.db_pool.simple_upsert_many(
+ table="users_in_public_rooms",
+ key_names=["user_id", "room_id"],
+ key_values=[(user_id, room_id) for user_id in user_ids],
+ value_names=(),
+ value_values=None,
+ desc="add_users_in_public_rooms",
)
async def delete_all_from_user_dir(self) -> None:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3d8da48f2d..02d71302ea 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -618,14 +618,7 @@ class _MultiWriterCtxManager:
db_autocommit=True,
)
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
with self.id_gen._lock:
- assert max(self.id_gen._current_positions.values(), default=0) < min(
- self.stream_ids
- )
-
self.id_gen._unfinished_ids.update(self.stream_ids)
if self.multiple_ids is None:
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
new file mode 100644
index 0000000000..f728cd2cf2
--- /dev/null
+++ b/synapse/util/caches/deferred_cache.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import enum
+import threading
+from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
+
+from prometheus_client import Gauge
+
+from twisted.internet import defer
+
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches import register_cache
+from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
+
+
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
+
+
+class DeferredCache(Generic[KT, VT]):
+ """Wraps an LruCache, adding support for Deferred results.
+
+ It expects that each entry added with set() will be a Deferred; likewise get()
+ may return an ObservableDeferred.
+ """
+
+ __slots__ = (
+ "cache",
+ "name",
+ "keylen",
+ "thread",
+ "metrics",
+ "_pending_deferred_cache",
+ )
+
+ def __init__(
+ self,
+ name: str,
+ max_entries: int = 1000,
+ keylen: int = 1,
+ tree: bool = False,
+ iterable: bool = False,
+ apply_cache_factor_from_config: bool = True,
+ ):
+ """
+ Args:
+ name: The name of the cache
+ max_entries: Maximum amount of entries that the cache will hold
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ `tree` is True.
+ tree: Use a TreeCache instead of a dict as the underlying cache type
+ iterable: If True, count each item in the cached object as an entry,
+ rather than each cached object
+ apply_cache_factor_from_config: Whether cache factors specified in the
+ config file affect `max_entries`
+ """
+ cache_type = TreeCache if tree else dict
+
+ # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
+ self._pending_deferred_cache = (
+ cache_type()
+ ) # type: MutableMapping[KT, CacheEntry]
+
+ # cache is used for completed results and maps to the result itself, rather than
+ # a Deferred.
+ self.cache = LruCache(
+ max_size=max_entries,
+ keylen=keylen,
+ cache_type=cache_type,
+ size_callback=(lambda d: len(d)) if iterable else None,
+ evicted_callback=self._on_evicted,
+ apply_cache_factor_from_config=apply_cache_factor_from_config,
+ )
+
+ self.name = name
+ self.keylen = keylen
+ self.thread = None # type: Optional[threading.Thread]
+ self.metrics = register_cache(
+ "cache",
+ name,
+ self.cache,
+ collect_callback=self._metrics_collection_callback,
+ )
+
+ @property
+ def max_entries(self):
+ return self.cache.max_size
+
+ def _on_evicted(self, evicted_count):
+ self.metrics.inc_evictions(evicted_count)
+
+ def _metrics_collection_callback(self):
+ cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
+ def get(
+ self,
+ key: KT,
+ default=_Sentinel.sentinel,
+ callback: Optional[Callable[[], None]] = None,
+ update_metrics: bool = True,
+ ):
+ """Looks the key up in the caches.
+
+ Args:
+ key(tuple)
+ default: What is returned if key is not in the caches. If not
+ specified then function throws KeyError instead
+ callback(fn): Gets called when the entry in the cache is invalidated
+ update_metrics (bool): whether to update the cache hit rate metrics
+
+ Returns:
+ Either an ObservableDeferred or the result itself
+ """
+ callbacks = [callback] if callback else []
+ val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+ if val is not _Sentinel.sentinel:
+ val.callbacks.update(callbacks)
+ if update_metrics:
+ self.metrics.inc_hits()
+ return val.deferred
+
+ val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
+ if val is not _Sentinel.sentinel:
+ self.metrics.inc_hits()
+ return val
+
+ if update_metrics:
+ self.metrics.inc_misses()
+
+ if default is _Sentinel.sentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def set(
+ self,
+ key: KT,
+ value: defer.Deferred,
+ callback: Optional[Callable[[], None]] = None,
+ ) -> ObservableDeferred:
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
+ callbacks = [callback] if callback else []
+ self.check_thread()
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = observable.observe()
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
+
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry:
+ existing_entry.invalidate()
+
+ self._pending_deferred_cache[key] = entry
+
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
+ self.cache.set(key, result, entry.callbacks)
+ else:
+ # we're not going to put this entry into the cache, so need
+ # to make sure that the invalidation callbacks are called.
+ # That was probably done when _pending_deferred_cache was
+ # updated, but it's possible that `set` was called without
+ # `invalidate` being previously called, in which case it may
+ # not have been. Either way, let's double-check now.
+ entry.invalidate()
+
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
+
+ def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
+ callbacks = [callback] if callback else []
+ self.cache.set(key, value, callbacks=callbacks)
+
+ def invalidate(self, key):
+ self.check_thread()
+ self.cache.pop(key, None)
+
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, which will (a) stop it being returned
+ # for future queries and (b) stop it being persisted as a proper entry
+ # in self.cache.
+ entry = self._pending_deferred_cache.pop(key, None)
+
+ # run the invalidation callbacks now, rather than waiting for the
+ # deferred to resolve.
+ if entry:
+ entry.invalidate()
+
+ def invalidate_many(self, key: KT):
+ self.check_thread()
+ if not isinstance(key, tuple):
+ raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+ self.cache.del_multi(key)
+
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, as above
+ entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
+ if entry_dict is not None:
+ for entry in iterate_tree_cache_entry(entry_dict):
+ entry.invalidate()
+
+ def invalidate_all(self):
+ self.check_thread()
+ self.cache.clear()
+ for entry in self._pending_deferred_cache.values():
+ entry.invalidate()
+ self._pending_deferred_cache.clear()
+
+
+class CacheEntry:
+ __slots__ = ["deferred", "callbacks", "invalidated"]
+
+ def __init__(
+ self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
+ ):
+ self.deferred = deferred
+ self.callbacks = set(callbacks)
+ self.invalidated = False
+
+ def invalidate(self):
+ if not self.invalidated:
+ self.invalidated = True
+ for callback in self.callbacks:
+ callback()
+ self.callbacks.clear()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..1f43886804 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,25 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import functools
import inspect
import logging
-import threading
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary
-from prometheus_client import Gauge
-
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-
-from . import register_cache
+from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__)
@@ -55,239 +48,6 @@ class _CachedFunction(Generic[F]):
__call__ = None # type: F
-cache_pending_metric = Gauge(
- "synapse_util_caches_cache_pending",
- "Number of lookups currently pending for this cache",
- ["name"],
-)
-
-_CacheSentinel = object()
-
-
-class CacheEntry:
- __slots__ = ["deferred", "callbacks", "invalidated"]
-
- def __init__(self, deferred, callbacks):
- self.deferred = deferred
- self.callbacks = set(callbacks)
- self.invalidated = False
-
- def invalidate(self):
- if not self.invalidated:
- self.invalidated = True
- for callback in self.callbacks:
- callback()
- self.callbacks.clear()
-
-
-class Cache:
- __slots__ = (
- "cache",
- "name",
- "keylen",
- "thread",
- "metrics",
- "_pending_deferred_cache",
- )
-
- def __init__(
- self,
- name: str,
- max_entries: int = 1000,
- keylen: int = 1,
- tree: bool = False,
- iterable: bool = False,
- apply_cache_factor_from_config: bool = True,
- ):
- """
- Args:
- name: The name of the cache
- max_entries: Maximum amount of entries that the cache will hold
- keylen: The length of the tuple used as the cache key
- tree: Use a TreeCache instead of a dict as the underlying cache type
- iterable: If True, count each item in the cached object as an entry,
- rather than each cached object
- apply_cache_factor_from_config: Whether cache factors specified in the
- config file affect `max_entries`
-
- Returns:
- Cache
- """
- cache_type = TreeCache if tree else dict
- self._pending_deferred_cache = cache_type()
-
- self.cache = LruCache(
- max_size=max_entries,
- keylen=keylen,
- cache_type=cache_type,
- size_callback=(lambda d: len(d)) if iterable else None,
- evicted_callback=self._on_evicted,
- apply_cache_factor_from_config=apply_cache_factor_from_config,
- )
-
- self.name = name
- self.keylen = keylen
- self.thread = None # type: Optional[threading.Thread]
- self.metrics = register_cache(
- "cache",
- name,
- self.cache,
- collect_callback=self._metrics_collection_callback,
- )
-
- @property
- def max_entries(self):
- return self.cache.max_size
-
- def _on_evicted(self, evicted_count):
- self.metrics.inc_evictions(evicted_count)
-
- def _metrics_collection_callback(self):
- cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
-
- def check_thread(self):
- expected_thread = self.thread
- if expected_thread is None:
- self.thread = threading.current_thread()
- else:
- if expected_thread is not threading.current_thread():
- raise ValueError(
- "Cache objects can only be accessed from the main thread"
- )
-
- def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
- """Looks the key up in the caches.
-
- Args:
- key(tuple)
- default: What is returned if key is not in the caches. If not
- specified then function throws KeyError instead
- callback(fn): Gets called when the entry in the cache is invalidated
- update_metrics (bool): whether to update the cache hit rate metrics
-
- Returns:
- Either an ObservableDeferred or the raw result
- """
- callbacks = [callback] if callback else []
- val = self._pending_deferred_cache.get(key, _CacheSentinel)
- if val is not _CacheSentinel:
- val.callbacks.update(callbacks)
- if update_metrics:
- self.metrics.inc_hits()
- return val.deferred
-
- val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
- if val is not _CacheSentinel:
- self.metrics.inc_hits()
- return val
-
- if update_metrics:
- self.metrics.inc_misses()
-
- if default is _CacheSentinel:
- raise KeyError()
- else:
- return default
-
- def set(self, key, value, callback=None):
- if not isinstance(value, defer.Deferred):
- raise TypeError("not a Deferred")
-
- callbacks = [callback] if callback else []
- self.check_thread()
- observable = ObservableDeferred(value, consumeErrors=True)
- observer = observable.observe()
- entry = CacheEntry(deferred=observable, callbacks=callbacks)
-
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry:
- existing_entry.invalidate()
-
- self._pending_deferred_cache[key] = entry
-
- def compare_and_pop():
- """Check if our entry is still the one in _pending_deferred_cache, and
- if so, pop it.
-
- Returns true if the entries matched.
- """
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- return True
-
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
- return False
-
- def cb(result):
- if compare_and_pop():
- self.cache.set(key, result, entry.callbacks)
- else:
- # we're not going to put this entry into the cache, so need
- # to make sure that the invalidation callbacks are called.
- # That was probably done when _pending_deferred_cache was
- # updated, but it's possible that `set` was called without
- # `invalidate` being previously called, in which case it may
- # not have been. Either way, let's double-check now.
- entry.invalidate()
-
- def eb(_fail):
- compare_and_pop()
- entry.invalidate()
-
- # once the deferred completes, we can move the entry from the
- # _pending_deferred_cache to the real cache.
- #
- observer.addCallbacks(cb, eb)
- return observable
-
- def prefill(self, key, value, callback=None):
- callbacks = [callback] if callback else []
- self.cache.set(key, value, callbacks=callbacks)
-
- def invalidate(self, key):
- self.check_thread()
- self.cache.pop(key, None)
-
- # if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, which will (a) stop it being returned
- # for future queries and (b) stop it being persisted as a proper entry
- # in self.cache.
- entry = self._pending_deferred_cache.pop(key, None)
-
- # run the invalidation callbacks now, rather than waiting for the
- # deferred to resolve.
- if entry:
- entry.invalidate()
-
- def invalidate_many(self, key):
- self.check_thread()
- if not isinstance(key, tuple):
- raise TypeError("The cache key must be a tuple not %r" % (type(key),))
- self.cache.del_multi(key)
-
- # if we have a pending lookup for this key, remove it from the
- # _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(key, None)
- if entry_dict is not None:
- for entry in iterate_tree_cache_entry(entry_dict):
- entry.invalidate()
-
- def invalidate_all(self):
- self.check_thread()
- self.cache.clear()
- for entry in self._pending_deferred_cache.values():
- entry.invalidate()
- self._pending_deferred_cache.clear()
-
-
class _CacheDescriptorBase:
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
self.orig = orig
@@ -390,13 +150,13 @@ class CacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
def __get__(self, obj, owner):
- cache = Cache(
+ cache = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
- )
+ ) # type: DeferredCache[Tuple, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@@ -640,9 +400,9 @@ class _CacheContext:
_cache_context_objects = (
WeakValueDictionary()
- ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+ ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
- def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
+ def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
self._cache = cache
self._cache_key = cache_key
@@ -651,7 +411,9 @@ class _CacheContext:
self._cache.invalidate(self._cache_key)
@classmethod
- def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
+ def get_instance(
+ cls, cache, cache_key
+ ): # type: (DeferredCache, CacheKey) -> _CacheContext
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4bc1a67b58..33eae2b7c4 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -64,7 +64,8 @@ class LruCache:
Args:
max_size: The maximum amount of entries the cache can hold
- keylen: The length of the tuple used as the cache key
+ keylen: The length of the tuple used as the cache key. Ignored unless
+ cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 3e180cafd3..6ce2a3d12b 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -34,7 +34,7 @@ class TTLCache:
self._data = {}
# the _CacheEntries, sorted by expiry time
- self._expiry_list = SortedList()
+ self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
self._timer = timer
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 641093d349..4a301b84e1 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=GenericWorkerServer
+ http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 0f016c32eb..c2b10d2c70 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -26,7 +26,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=GenericWorkerServer
+ http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- http_client=None, homeserverToUse=SynapseHomeServer
+ http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 68a4caabbf..2acb8b7603 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events # txn made and saved
+ service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events # txn made and saved
+ service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
- service=service, events=events
+ service=service, events=events, ephemeral=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@@ -202,26 +202,28 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
- self.queuer.enqueue(service, event)
- self.txn_ctrl.send.assert_called_once_with(service, [event])
+ self.queuer.enqueue_event(service, event)
+ self.txn_ctrl.send.assert_called_once_with(service, [event], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
- self.queuer.enqueue(service, event)
+ self.queuer.enqueue_event(service, event)
# Send more events: expect send() to NOT be called multiple times.
- self.queuer.enqueue(service, event2)
- self.queuer.enqueue(service, event3)
- self.txn_ctrl.send.assert_called_with(service, [event])
+ self.queuer.enqueue_event(service, event2)
+ self.queuer.enqueue_event(service, event3)
+ self.txn_ctrl.send.assert_called_with(service, [event], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
- self.txn_ctrl.send.assert_called_with(service, [event2, event3])
+ self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@@ -239,21 +241,58 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
- def do_send(x, y):
+ def do_send(x, y, z):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
- self.queuer.enqueue(srv1, srv_1_event)
- self.queuer.enqueue(srv1, srv_1_event2)
- self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event])
- self.queuer.enqueue(srv2, srv_2_event)
- self.queuer.enqueue(srv2, srv_2_event2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event])
+ self.queuer.enqueue_event(srv1, srv_1_event)
+ self.queuer.enqueue_event(srv1, srv_1_event2)
+ self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
+ self.queuer.enqueue_event(srv2, srv_2_event)
+ self.queuer.enqueue_event(srv2, srv_2_event2)
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
- self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2])
+ self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
+
+ def test_send_single_ephemeral_no_queue(self):
+ # Expect the event to be sent immediately.
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event")]
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+
+ def test_send_multiple_ephemeral_no_queue(self):
+ # Expect the events to be sent immediately.
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
+
+ def test_send_single_ephemeral_with_queue(self):
+ d = defer.Deferred()
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
+ service = Mock(id=4)
+ event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
+ event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
+ event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
+
+ # Send an event and don't resolve it just yet.
+ self.queuer.enqueue_ephemeral(service, event_list_1)
+ # Send more events: expect send() to NOT be called multiple times.
+ self.queuer.enqueue_ephemeral(service, event_list_2)
+ self.queuer.enqueue_ephemeral(service, event_list_3)
+ self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
+ self.assertEquals(1, self.txn_ctrl.send.call_count)
+ # Resolve txn_ctrl.send
+ d.callback(service)
+ # Expect the queued events to be sent
+ self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
+ self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 2a0b7c1b56..ee4f3da31c 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -18,6 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.types import RoomStreamToken
from tests.test_utils import make_awaitable
from tests.utils import MockClock
@@ -61,7 +62,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+ yield defer.ensureDeferred(
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ )
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@@ -80,7 +83,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+ yield defer.ensureDeferred(
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ )
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@defer.inlineCallbacks
@@ -97,7 +102,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+ yield defer.ensureDeferred(
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ )
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been.",
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 64e28bc639..9f6f21a6e2 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
"sender": self.requester.user.to_string(),
"content": {"msgtype": "m.text", "body": random_string(5)},
},
- token_id=self.token_id,
txn_id=txn_id,
)
)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 914c82e7a8..8ed67640f8 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(builder.build(prev_event_ids))
+ event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 81ea985b9f..093e2faac7 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -59,7 +59,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
- homeserverToUse=GenericWorkerServer,
+ homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
- homeserverToUse=GenericWorkerServer,
+ homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 23be1167a3..1853667558 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+ hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 9c4a9c3563..779745ae9d 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
- join_event = self.get_success(builder.build(prev_event_ids))
+ join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 6068d14905..82cf033d4e 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,8 +14,12 @@
# limitations under the License.
import logging
+from mock import patch
+
+from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -36,6 +40,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
+ sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -43,6 +48,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
+ self.room_creator = self.hs.get_room_creation_handler()
+ self.store = hs.get_datastore()
+
def default_config(self):
conf = super().default_config()
conf["redis"] = {"enabled": "true"}
@@ -53,6 +61,29 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
+ def _create_room(self, room_id: str, user_id: str, tok: str):
+ """Create a room with given room_id
+ """
+
+ # We control the room ID generation by patching out the
+ # `_generate_room_id` method
+ async def generate_room(
+ creator_id: str, is_public: bool, room_version: RoomVersion
+ ):
+ await self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id=creator_id,
+ is_public=is_public,
+ room_version=room_version,
+ )
+ return room_id
+
+ with patch(
+ "synapse.handlers.room.RoomCreationHandler._generate_room_id"
+ ) as mock:
+ mock.side_effect = generate_room
+ self.helper.create_room_as(user_id, tok=tok)
+
def test_basic(self):
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
@@ -100,3 +131,189 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)
+
+ def test_vector_clock_token(self):
+ """Tests that using a stream token with a vector clock component works
+ correctly with basic /sync and /messages usage.
+ """
+
+ self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker1"},
+ )
+
+ worker_hs2 = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "worker2"},
+ )
+
+ sync_hs = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "sync"},
+ )
+
+ # Specially selected room IDs that get persisted on different workers.
+ room_id1 = "!foo:test"
+ room_id2 = "!baz:test"
+
+ self.assertEqual(
+ self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1"
+ )
+ self.assertEqual(
+ self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2"
+ )
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ store = self.hs.get_datastore()
+
+ # Create two rooms on the different workers.
+ self._create_room(room_id1, user_id, access_token)
+ self._create_room(room_id2, user_id, access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room_id1, user=self.other_user_id, tok=self.other_access_token
+ )
+ self.helper.join(
+ room=room_id2, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # Do an initial sync so that we're up to date.
+ request, channel = self.make_request("GET", "/sync", access_token=access_token)
+ self.render_on_worker(sync_hs, request)
+ next_batch = channel.json_body["next_batch"]
+
+ # We now gut wrench into the events stream MultiWriterIdGenerator on
+ # worker2 to mimic it getting stuck persisting an event. This ensures
+ # that when we send an event on worker1 we end up in a state where
+ # worker2 events stream position lags that on worker1, resulting in a
+ # RoomStreamToken with a non-empty instance map component.
+ #
+ # Worker2's event stream position will not advance until we call
+ # __aexit__ on this context manager below.
+ actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+ self.get_success(actx.__aenter__())
+
+ response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
+ first_event_in_room1 = response["event_id"]
+
+ # Assert that the current stream token has an instance map component, as
+ # we are trying to test vector clock tokens.
+ room_stream_token = store.get_room_max_token()
+ self.assertNotEqual(len(room_stream_token.instance_map), 0)
+
+ # Check that syncing still gets the new event, despite the gap in the
+ # stream IDs.
+ request, channel = self.make_request(
+ "GET", "/sync?since={}".format(next_batch), access_token=access_token
+ )
+ self.render_on_worker(sync_hs, request)
+
+ # We should only see the new event and nothing else
+ self.assertIn(room_id1, channel.json_body["rooms"]["join"])
+ self.assertNotIn(room_id2, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"]
+ self.assertListEqual(
+ [first_event_in_room1], [event["event_id"] for event in events]
+ )
+
+ # Get the next batch and make sure it's a vector clock style token.
+ vector_clock_token = channel.json_body["next_batch"]
+ self.assertTrue(vector_clock_token.startswith("m"))
+
+ # Now that we've got a vector clock token we finish the fake persistence
+ # of the event that we started above.
+ self.get_success(actx.__aexit__(None, None, None))
+
+ # Now try and send an event to the other room so that we can test that
+ # the vector clock style token works as a `since` token.
+ response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
+ first_event_in_room2 = response["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/sync?since={}".format(vector_clock_token),
+ access_token=access_token,
+ )
+ self.render_on_worker(sync_hs, request)
+
+ self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
+ self.assertIn(room_id2, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"]
+ self.assertListEqual(
+ [first_event_in_room2], [event["event_id"] for event in events]
+ )
+
+ next_batch = channel.json_body["next_batch"]
+
+ # We also want to test that the vector clock style token works with
+ # pagination. We do this by sending a couple of new events into the room
+ # and syncing again to get a prev_batch token for each room, then
+ # paginating from there back to the vector clock token.
+ self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
+ self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "GET", "/sync?since={}".format(next_batch), access_token=access_token
+ )
+ self.render_on_worker(sync_hs, request)
+
+ prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
+ "prev_batch"
+ ]
+ prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][
+ "prev_batch"
+ ]
+
+ # Paginating back in the first room should not produce any results, as
+ # no events have happened in it. This tests that we are correctly
+ # filtering results based on the vector clock portion.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=b".format(
+ room_id1, prev_batch1, vector_clock_token
+ ),
+ access_token=access_token,
+ )
+ self.render_on_worker(sync_hs, request)
+ self.assertListEqual([], channel.json_body["chunk"])
+
+ # Paginating back on the second room should produce the first event
+ # again. This tests that pagination isn't completely broken.
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=b".format(
+ room_id2, prev_batch2, vector_clock_token
+ ),
+ access_token=access_token,
+ )
+ self.render_on_worker(sync_hs, request)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertEqual(
+ channel.json_body["chunk"][0]["event_id"], first_event_in_room2
+ )
+
+ # Paginating forwards should give the same results
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=f".format(
+ room_id1, vector_clock_token, prev_batch1
+ ),
+ access_token=access_token,
+ )
+ self.render_on_worker(sync_hs, request)
+ self.assertListEqual([], channel.json_body["chunk"])
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/{}/messages?from={}&to={}&dir=f".format(
+ room_id2, vector_clock_token, prev_batch2,
+ ),
+ access_token=access_token,
+ )
+ self.render_on_worker(sync_hs, request)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertEqual(
+ channel.json_body["chunk"][0]["event_id"], first_event_in_room2
+ )
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index f5afed017c..8e69b1e9cc 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -20,82 +20,11 @@ from mock import Mock
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import Cache, cached
+from synapse.util.caches.descriptors import cached
from tests import unittest
-class CacheTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
- self.cache = Cache("test")
-
- def test_empty(self):
- failed = False
- try:
- self.cache.get("foo")
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- def test_hit(self):
- self.cache.prefill("foo", 123)
-
- self.assertEquals(self.cache.get("foo"), 123)
-
- def test_invalidate(self):
- self.cache.prefill(("foo",), 123)
- self.cache.invalidate(("foo",))
-
- failed = False
- try:
- self.cache.get(("foo",))
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- def test_eviction(self):
- cache = Cache("test", max_entries=2)
-
- cache.prefill(1, "one")
- cache.prefill(2, "two")
- cache.prefill(3, "three") # 1 will be evicted
-
- failed = False
- try:
- cache.get(1)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- cache.get(2)
- cache.get(3)
-
- def test_eviction_lru(self):
- cache = Cache("test", max_entries=2)
-
- cache.prefill(1, "one")
- cache.prefill(2, "two")
-
- # Now access 1 again, thus causing 2 to be least-recently used
- cache.get(1)
-
- cache.prefill(3, "three")
-
- failed = False
- try:
- cache.get(2)
- except KeyError:
- failed = True
-
- self.assertTrue(failed)
-
- cache.get(1)
- cache.get(3)
-
-
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index c905a38930..c5c7987349 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield defer.ensureDeferred(
- self.store.create_appservice_txn(service, events)
+ self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 1ea35d60c1..d4f9e809db 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id
@defer.inlineCallbacks
- def build(self, prev_event_ids):
+ def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids)
+ self._base_builder.build(prev_event_ids, auth_event_ids)
)
built_event._event_id = self._event_id
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index f5f63d8ed6..759e4cd048 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,7 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
@@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
Caches produce metrics reflecting their state when scraped.
"""
CACHE_NAME = "cache_metrics_test_fgjkbdfg"
- cache = Cache(CACHE_NAME, max_entries=777)
+ cache = DeferredCache(CACHE_NAME, max_entries=777)
items = {
x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
diff --git a/tests/unittest.py b/tests/unittest.py
index 6c1661c92c..040b126a27 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union
+from typing import Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -364,6 +364,36 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
+ # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
+ # when the `request` arg isn't given, so we define an explicit overload to
+ # cover that case.
+ @overload
+ def make_request(
+ self,
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, dict] = b"",
+ access_token: Optional[str] = None,
+ shorthand: bool = True,
+ federation_auth_origin: str = None,
+ content_is_form: bool = False,
+ ) -> Tuple[SynapseRequest, FakeChannel]:
+ ...
+
+ @overload
+ def make_request(
+ self,
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, dict] = b"",
+ access_token: Optional[str] = None,
+ request: Type[T] = SynapseRequest,
+ shorthand: bool = True,
+ federation_auth_origin: str = None,
+ content_is_form: bool = False,
+ ) -> Tuple[T, FakeChannel]:
+ ...
+
def make_request(
self,
method: Union[bytes, str],
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
new file mode 100644
index 0000000000..9717be56b6
--- /dev/null
+++ b/tests/util/caches/test_deferred_cache.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+
+from twisted.internet import defer
+
+from synapse.util.caches.deferred_cache import DeferredCache
+
+
+class DeferredCacheTestCase(unittest.TestCase):
+ def test_empty(self):
+ cache = DeferredCache("test")
+ failed = False
+ try:
+ cache.get("foo")
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_hit(self):
+ cache = DeferredCache("test")
+ cache.prefill("foo", 123)
+
+ self.assertEquals(cache.get("foo"), 123)
+
+ def test_invalidate(self):
+ cache = DeferredCache("test")
+ cache.prefill(("foo",), 123)
+ cache.invalidate(("foo",))
+
+ failed = False
+ try:
+ cache.get(("foo",))
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ def test_invalidate_all(self):
+ cache = DeferredCache("testcache")
+
+ callback_record = [False, False]
+
+ def record_callback(idx):
+ callback_record[idx] = True
+
+ # add a couple of pending entries
+ d1 = defer.Deferred()
+ cache.set("key1", d1, partial(record_callback, 0))
+
+ d2 = defer.Deferred()
+ cache.set("key2", d2, partial(record_callback, 1))
+
+ # lookup should return observable deferreds
+ self.assertFalse(cache.get("key1").has_called())
+ self.assertFalse(cache.get("key2").has_called())
+
+ # let one of the lookups complete
+ d2.callback("result2")
+
+ # for now at least, the cache will return real results rather than an
+ # observabledeferred
+ self.assertEqual(cache.get("key2"), "result2")
+
+ # now do the invalidation
+ cache.invalidate_all()
+
+ # lookup should return none
+ self.assertIsNone(cache.get("key1", None))
+ self.assertIsNone(cache.get("key2", None))
+
+ # both invalidation callbacks should have been called
+ self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
+ self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
+
+ # letting the other lookup complete should do nothing
+ d1.callback("result1")
+ self.assertIsNone(cache.get("key1", None))
+
+ def test_eviction(self):
+ cache = DeferredCache(
+ "test", max_entries=2, apply_cache_factor_from_config=False
+ )
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+ cache.prefill(3, "three") # 1 will be evicted
+
+ failed = False
+ try:
+ cache.get(1)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(2)
+ cache.get(3)
+
+ def test_eviction_lru(self):
+ cache = DeferredCache(
+ "test", max_entries=2, apply_cache_factor_from_config=False
+ )
+
+ cache.prefill(1, "one")
+ cache.prefill(2, "two")
+
+ # Now access 1 again, thus causing 2 to be least-recently used
+ cache.get(1)
+
+ cache.prefill(3, "three")
+
+ failed = False
+ try:
+ cache.get(2)
+ except KeyError:
+ failed = True
+
+ self.assertTrue(failed)
+
+ cache.get(1)
+ cache.get(3)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 677e925477..3d1f960869 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from functools import partial
import mock
@@ -42,49 +41,6 @@ def run_on_reactor():
return make_deferred_yieldable(d)
-class CacheTestCase(unittest.TestCase):
- def test_invalidate_all(self):
- cache = descriptors.Cache("testcache")
-
- callback_record = [False, False]
-
- def record_callback(idx):
- callback_record[idx] = True
-
- # add a couple of pending entries
- d1 = defer.Deferred()
- cache.set("key1", d1, partial(record_callback, 0))
-
- d2 = defer.Deferred()
- cache.set("key2", d2, partial(record_callback, 1))
-
- # lookup should return observable deferreds
- self.assertFalse(cache.get("key1").has_called())
- self.assertFalse(cache.get("key2").has_called())
-
- # let one of the lookups complete
- d2.callback("result2")
-
- # for now at least, the cache will return real results rather than an
- # observabledeferred
- self.assertEqual(cache.get("key2"), "result2")
-
- # now do the invalidation
- cache.invalidate_all()
-
- # lookup should return none
- self.assertIsNone(cache.get("key1", None))
- self.assertIsNone(cache.get("key2", None))
-
- # both callbacks should have been callbacked
- self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
- self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
-
- # letting the other lookup complete should do nothing
- d1.callback("result1")
- self.assertIsNone(cache.get("key1", None))
-
-
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
diff --git a/tests/utils.py b/tests/utils.py
index 0c09f5457f..acec74e9e9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,6 +21,7 @@ import time
import uuid
import warnings
from inspect import getcallargs
+from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
@@ -194,8 +195,8 @@ def setup_test_homeserver(
name="test",
config=None,
reactor=None,
- homeserverToUse=TestHomeServer,
- **kargs
+ homeserver_to_use: Type[HomeServer] = TestHomeServer,
+ **kwargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -218,8 +219,8 @@ def setup_test_homeserver(
config.ldap_enabled = False
- if "clock" not in kargs:
- kargs["clock"] = MockClock()
+ if "clock" not in kwargs:
+ kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
@@ -264,18 +265,20 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- hs = homeserverToUse(
- name,
- config=config,
- version_string="Synapse/tests",
- tls_server_context_factory=Mock(),
- tls_client_options_factory=Mock(),
- reactor=reactor,
- **kargs
+ hs = homeserver_to_use(
+ name, config=config, version_string="Synapse/tests", reactor=reactor,
)
+ # Install @cache_in_self attributes
+ for key, val in kwargs.items():
+ setattr(hs, key, val)
+
+ # Mock TLS
+ hs.tls_server_context_factory = Mock()
+ hs.tls_client_options_factory = Mock()
+
hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
+ if homeserver_to_use == TestHomeServer:
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
@@ -339,7 +342,7 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash
- fed = kargs.get("resource_for_federation", None)
+ fed = kwargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)
|