diff --git a/changelog.d/10556.doc b/changelog.d/10556.doc
new file mode 100644
index 0000000000..7526ae11db
--- /dev/null
+++ b/changelog.d/10556.doc
@@ -0,0 +1 @@
+Minor fix to the `media_repository` developer documentation. Contributed by @cuttingedge1109.
\ No newline at end of file
diff --git a/changelog.d/10566.feature b/changelog.d/10566.feature
new file mode 100644
index 0000000000..04575d76a9
--- /dev/null
+++ b/changelog.d/10566.feature
@@ -0,0 +1 @@
+Allow room creators to send historical events specified by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) in existing room versions.
diff --git a/changelog.d/10643.feature b/changelog.d/10643.feature
new file mode 100644
index 0000000000..bd63a3d258
--- /dev/null
+++ b/changelog.d/10643.feature
@@ -0,0 +1 @@
+Add config option to use non-default manhole password and keys.
\ No newline at end of file
diff --git a/changelog.d/10658.bugfix b/changelog.d/10658.bugfix
new file mode 100644
index 0000000000..a59d402933
--- /dev/null
+++ b/changelog.d/10658.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where room avatars were not included in email notifications.
diff --git a/changelog.d/10697.misc b/changelog.d/10697.misc
new file mode 100644
index 0000000000..a9ad17faf2
--- /dev/null
+++ b/changelog.d/10697.misc
@@ -0,0 +1 @@
+Ensure `rooms.creator` field is always populated for easy lookup in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) usage later.
diff --git a/changelog.d/10704.bugfix b/changelog.d/10704.bugfix
new file mode 100644
index 0000000000..4284cddc2b
--- /dev/null
+++ b/changelog.d/10704.bugfix
@@ -0,0 +1 @@
+Added opentrace logging to help debug #9424.
\ No newline at end of file
diff --git a/changelog.d/10707.misc b/changelog.d/10707.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10707.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10712.feature b/changelog.d/10712.feature
new file mode 100644
index 0000000000..d04db6f26f
--- /dev/null
+++ b/changelog.d/10712.feature
@@ -0,0 +1 @@
+Skip final GC at shutdown to improve restart performance.
diff --git a/changelog.d/10714.feature b/changelog.d/10714.feature
new file mode 100644
index 0000000000..7d18f5c133
--- /dev/null
+++ b/changelog.d/10714.feature
@@ -0,0 +1 @@
+Allow configuration of the oEmbed URLs used for URL previews.
diff --git a/changelog.d/10727.misc b/changelog.d/10727.misc
new file mode 100644
index 0000000000..63fe6e5c7d
--- /dev/null
+++ b/changelog.d/10727.misc
@@ -0,0 +1 @@
+Do not include rooms with unknown room versions in the spaces summary results.
diff --git a/changelog.d/10728.misc b/changelog.d/10728.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10728.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10730.bugfix b/changelog.d/10730.bugfix
new file mode 100644
index 0000000000..f1612d3c08
--- /dev/null
+++ b/changelog.d/10730.bugfix
@@ -0,0 +1 @@
+Fix a bug where the ordering algorithm was skipping the `origin_server_ts` step in the spaces summary resulting in unstable room orderings.
diff --git a/changelog.d/10735.doc b/changelog.d/10735.doc
new file mode 100644
index 0000000000..5d6207afb9
--- /dev/null
+++ b/changelog.d/10735.doc
@@ -0,0 +1 @@
+Clarify admin API documentation on undoing room deletions.
diff --git a/changelog.d/10736.misc b/changelog.d/10736.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10736.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10738.misc b/changelog.d/10738.misc
new file mode 100644
index 0000000000..cef54153dc
--- /dev/null
+++ b/changelog.d/10738.misc
@@ -0,0 +1 @@
+Additional error checking for the `preset` field when creating a room.
diff --git a/changelog.d/10743.bugfix b/changelog.d/10743.bugfix
new file mode 100644
index 0000000000..d597a19870
--- /dev/null
+++ b/changelog.d/10743.bugfix
@@ -0,0 +1 @@
+Fix edge case when persisting events into a room where there are multiple events we previously hadn't calculated auth chains for (and hadn't marked as needing to be calculated).
diff --git a/changelog.d/10744.misc b/changelog.d/10744.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10744.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10745.misc b/changelog.d/10745.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10745.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10748.misc b/changelog.d/10748.misc
new file mode 100644
index 0000000000..b9e2c46087
--- /dev/null
+++ b/changelog.d/10748.misc
@@ -0,0 +1 @@
+Add an index to `presence_stream` to hopefully speed up startups a little.
diff --git a/changelog.d/10750.misc b/changelog.d/10750.misc
new file mode 100644
index 0000000000..ded5cf626c
--- /dev/null
+++ b/changelog.d/10750.misc
@@ -0,0 +1 @@
+Refactor event size checking code to simplify searching the codebase for the origins of certain error strings that are occasionally emitted.
\ No newline at end of file
diff --git a/changelog.d/10752.misc b/changelog.d/10752.misc
new file mode 100644
index 0000000000..5f9aa23018
--- /dev/null
+++ b/changelog.d/10752.misc
@@ -0,0 +1 @@
+Move tests relating to rooms having encryption out of the user_directory tests.
\ No newline at end of file
diff --git a/changelog.d/10754.misc b/changelog.d/10754.misc
new file mode 100644
index 0000000000..3b7acff03f
--- /dev/null
+++ b/changelog.d/10754.misc
@@ -0,0 +1 @@
+Minor speed ups when joining large rooms over federation.
diff --git a/changelog.d/10755.misc b/changelog.d/10755.misc
new file mode 100644
index 0000000000..3b7acff03f
--- /dev/null
+++ b/changelog.d/10755.misc
@@ -0,0 +1 @@
+Minor speed ups when joining large rooms over federation.
diff --git a/changelog.d/10756.misc b/changelog.d/10756.misc
new file mode 100644
index 0000000000..3b7acff03f
--- /dev/null
+++ b/changelog.d/10756.misc
@@ -0,0 +1 @@
+Minor speed ups when joining large rooms over federation.
diff --git a/changelog.d/10757.bugfix b/changelog.d/10757.bugfix
new file mode 100644
index 0000000000..bce36ef242
--- /dev/null
+++ b/changelog.d/10757.bugfix
@@ -0,0 +1 @@
+Fix a bug which prevented calls to `/createRoom` that included the `room_alias_name` parameter from being handled by worker processes.
\ No newline at end of file
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 48777dd231..8e524e6509 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -481,32 +481,44 @@ The following fields are returned in the JSON response body:
* `new_room_id` - A string representing the room ID of the new room.
-## Undoing room shutdowns
+## Undoing room deletions
-*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
+*Note*: This guide may be outdated by the time you read it. By nature of room deletions being performed at the database level,
the structure can and does change without notice.
-First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
+First, it's important to understand that a room deletion is very destructive. Undoing a deletion is not as simple as pretending it
never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible
to recover at all:
* If the room was invite-only, your users will need to be re-invited.
* If the room no longer has any members at all, it'll be impossible to rejoin.
-* The first user to rejoin will have to do so via an alias on a different server.
+* The first user to rejoin will have to do so via an alias on a different
+ server (or receive an invite from a user on a different server).
With all that being said, if you still want to try and recover the room:
-1. For safety reasons, shut down Synapse.
-2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
- * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
- * The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
-3. Restart Synapse.
+1. If the room was `block`ed, you must unblock it on your server. This can be
+ accomplished as follows:
-You will have to manually handle, if you so choose, the following:
+ 1. For safety reasons, shut down Synapse.
+ 2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
+ * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
+ * The room ID is the same one supplied to the delete room API, not the Content Violation room.
+ 3. Restart Synapse.
-* Aliases that would have been redirected to the Content Violation room.
-* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
-* Removal of the Content Violation room if desired.
+ This step is unnecessary if `block` was not set.
+
+2. Any room aliases on your server that pointed to the deleted room may have
+ been deleted, or redirected to the Content Violation room. These will need
+ to be restored manually.
+
+3. Users on your server that were in the deleted room will have been kicked
+ from the room. Consider whether you want to update their membership
+ (possibly via the [Edit Room Membership API](room_membership.md)) or let
+ them handle rejoining themselves.
+
+4. If `new_room_user_id` was given, a 'Content Violation' will have been
+ created. Consider whether you want to delete that roomm.
## Deprecated endpoint
@@ -536,7 +548,7 @@ POST /_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
# Forward Extremities Admin API
Enables querying and deleting forward extremities from rooms. When a lot of forward
-extremities accumulate in a room, performance can become degraded. For details, see
+extremities accumulate in a room, performance can become degraded. For details, see
[#1760](https://github.com/matrix-org/synapse/issues/1760).
## Check for forward extremities
@@ -565,7 +577,7 @@ A response as follows will be returned:
## Deleting forward extremities
-**WARNING**: Please ensure you know what you're doing and have read
+**WARNING**: Please ensure you know what you're doing and have read
the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
Under no situations should this API be executed as an automated maintenance task!
diff --git a/docs/manhole.md b/docs/manhole.md
index db92df88dc..715ed840f2 100644
--- a/docs/manhole.md
+++ b/docs/manhole.md
@@ -11,7 +11,7 @@ Note that this will give administrative access to synapse to **all users** with
shell access to the server. It should therefore **not** be enabled in
environments where untrusted users have shell access.
-***
+## Configuring the manhole
To enable it, first uncomment the `manhole` listener configuration in
`homeserver.yaml`. The configuration is slightly different if you're using docker.
@@ -52,16 +52,37 @@ listeners:
type: manhole
```
-#### Accessing synapse manhole
+### Security settings
+
+The following config options are available:
+
+- `username` - The username for the manhole (defaults to `matrix`)
+- `password` - The password for the manhole (defaults to `rabbithole`)
+- `ssh_priv_key` - The path to a private SSH key (defaults to a hardcoded value)
+- `ssh_pub_key` - The path to a public SSH key (defaults to a hardcoded value)
+
+For example:
+
+```yaml
+manhole_settings:
+ username: manhole
+ password: mypassword
+ ssh_priv_key: "/home/synapse/manhole_keys/id_rsa"
+ ssh_pub_key: "/home/synapse/manhole_keys/id_rsa.pub"
+```
+
+
+## Accessing synapse manhole
Then restart synapse, and point an ssh client at port 9000 on localhost, using
-the username `matrix`:
+the username and password configured in `homeserver.yaml` - with the default
+configuration, this would be:
```bash
ssh -p9000 matrix@localhost
```
-The password is `rabbithole`.
+Then enter the password when prompted (the default is `rabbithole`).
This gives a Python REPL in which `hs` gives access to the
`synapse.server.HomeServer` object - which in turn gives access to many other
diff --git a/docs/media_repository.md b/docs/media_repository.md
index 1bf8f16f55..99ee8f1ef7 100644
--- a/docs/media_repository.md
+++ b/docs/media_repository.md
@@ -27,4 +27,4 @@ Remote content is cached under `"remote_content"` directory. Each item of
remote content is assigned a local `"filesystem_id"` to ensure that the
directory structure `"remote_content/server_name/aa/bb/ccccccccdddddddddddd"`
is appropriate. Thumbnails for remote content are stored under
-`"remote_thumbnails/server_name/..."`
+`"remote_thumbnail/server_name/..."`
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 935841dbfa..e15a832220 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -335,6 +335,24 @@ listeners:
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
+# Connection settings for the manhole
+#
+manhole_settings:
+ # The username for the manhole. This defaults to 'matrix'.
+ #
+ #username: manhole
+
+ # The password for the manhole. This defaults to 'rabbithole'.
+ #
+ #password: mypassword
+
+ # The private and public SSH key pair used to encrypt the manhole traffic.
+ # If these are left unset, then hardcoded and non-secret keys are used,
+ # which could allow traffic to be intercepted if sent over a public network.
+ #
+ #ssh_priv_key_path: CONFDIR/id_rsa
+ #ssh_pub_key_path: CONFDIR/id_rsa.pub
+
# Forward extremities can build up in a room due to networking delays between
# homeservers. Once this happens in a large room, calculation of the state of
# that room can become quite expensive. To mitigate this, once the number of
@@ -1075,6 +1093,27 @@ url_preview_accept_language:
# - en
+# oEmbed allows for easier embedding content from a website. It can be
+# used for generating URLs previews of services which support it.
+#
+oembed:
+ # A default list of oEmbed providers is included with Synapse.
+ #
+ # Uncomment the following to disable using these default oEmbed URLs.
+ # Defaults to 'false'.
+ #
+ #disable_default_providers: true
+
+ # Additional files with oEmbed configuration (each should be in the
+ # form of providers.json).
+ #
+ # By default, this list is empty (so only the default providers.json
+ # is used).
+ #
+ #additional_providers:
+ # - oembed/my_providers.json
+
+
## Captcha ##
# See docs/CAPTCHA_SETUP.md for full details of configuring this.
diff --git a/mypy.ini b/mypy.ini
index f6de668edd..4096f72241 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -90,6 +90,7 @@ files =
tests/test_event_auth.py,
tests/test_utils,
tests/handlers/test_password_providers.py,
+ tests/handlers/test_room.py,
tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/handlers/test_sync.py,
@@ -98,6 +99,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
+[mypy-synapse.rest.client.*]
+disallow_untyped_defs = True
+
[mypy-pymacaroons.*]
ignore_missing_imports = True
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 2bbaf5557d..fa6ac6d93a 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -46,6 +46,7 @@ from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
+from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
@@ -179,6 +180,7 @@ class Store(
EndToEndKeyBackgroundStore,
StatsStore,
PusherWorkerStore,
+ PresenceBackgroundUpdateStore,
):
def execute(self, f, *args, **kwargs):
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi
index fa307483fe..0602a4fa90 100644
--- a/stubs/sortedcontainers/__init__.pyi
+++ b/stubs/sortedcontainers/__init__.pyi
@@ -1,5 +1,6 @@
from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView
from .sortedlist import SortedKeyList, SortedList, SortedListWithKey
+from .sortedset import SortedSet
__all__ = [
"SortedDict",
@@ -9,4 +10,5 @@ __all__ = [
"SortedKeyList",
"SortedList",
"SortedListWithKey",
+ "SortedSet",
]
diff --git a/stubs/sortedcontainers/sortedset.pyi b/stubs/sortedcontainers/sortedset.pyi
new file mode 100644
index 0000000000..f9c2908386
--- /dev/null
+++ b/stubs/sortedcontainers/sortedset.pyi
@@ -0,0 +1,118 @@
+# stub for SortedSet. This is a lightly edited copy of
+# https://github.com/grantjenks/python-sortedcontainers/blob/d0a225d7fd0fb4c54532b8798af3cbeebf97e2d5/sortedcontainers/sortedset.pyi
+# (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
+
+from typing import (
+ AbstractSet,
+ Any,
+ Callable,
+ Generic,
+ Hashable,
+ Iterable,
+ Iterator,
+ List,
+ MutableSet,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+
+# --- Global
+
+_T = TypeVar("_T", bound=Hashable)
+_S = TypeVar("_S", bound=Hashable)
+_SS = TypeVar("_SS", bound=SortedSet)
+_Key = Callable[[_T], Any]
+
+class SortedSet(MutableSet[_T], Sequence[_T]):
+ def __init__(
+ self,
+ iterable: Optional[Iterable[_T]] = ...,
+ key: Optional[_Key[_T]] = ...,
+ ) -> None: ...
+ @classmethod
+ def _fromset(
+ cls, values: Set[_T], key: Optional[_Key[_T]] = ...
+ ) -> SortedSet[_T]: ...
+ @property
+ def key(self) -> Optional[_Key[_T]]: ...
+ def __contains__(self, value: Any) -> bool: ...
+ @overload
+ def __getitem__(self, index: int) -> _T: ...
+ @overload
+ def __getitem__(self, index: slice) -> List[_T]: ...
+ def __delitem__(self, index: Union[int, slice]) -> None: ...
+ def __eq__(self, other: Any) -> bool: ...
+ def __ne__(self, other: Any) -> bool: ...
+ def __lt__(self, other: Iterable[_T]) -> bool: ...
+ def __gt__(self, other: Iterable[_T]) -> bool: ...
+ def __le__(self, other: Iterable[_T]) -> bool: ...
+ def __ge__(self, other: Iterable[_T]) -> bool: ...
+ def __len__(self) -> int: ...
+ def __iter__(self) -> Iterator[_T]: ...
+ def __reversed__(self) -> Iterator[_T]: ...
+ def add(self, value: _T) -> None: ...
+ def _add(self, value: _T) -> None: ...
+ def clear(self) -> None: ...
+ def copy(self: _SS) -> _SS: ...
+ def __copy__(self: _SS) -> _SS: ...
+ def count(self, value: _T) -> int: ...
+ def discard(self, value: _T) -> None: ...
+ def _discard(self, value: _T) -> None: ...
+ def pop(self, index: int = ...) -> _T: ...
+ def remove(self, value: _T) -> None: ...
+ def difference(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def difference_update(
+ self, *iterables: Iterable[_S]
+ ) -> SortedSet[Union[_T, _S]]: ...
+ def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def intersection(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __and__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def intersection_update(
+ self, *iterables: Iterable[_S]
+ ) -> SortedSet[Union[_T, _S]]: ...
+ def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __xor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __rxor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def symmetric_difference_update(
+ self, other: Iterable[_S]
+ ) -> SortedSet[Union[_T, _S]]: ...
+ def __ixor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def union(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __or__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def _update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
+ def __reduce__(
+ self,
+ ) -> Tuple[Type[SortedSet[_T]], Set[_T], Callable[[_T], Any]]: ...
+ def __repr__(self) -> str: ...
+ def _check(self) -> None: ...
+ def bisect_left(self, value: _T) -> int: ...
+ def bisect_right(self, value: _T) -> int: ...
+ def islice(
+ self,
+ start: Optional[int] = ...,
+ stop: Optional[int] = ...,
+ reverse=bool,
+ ) -> Iterator[_T]: ...
+ def irange(
+ self,
+ minimum: Optional[_T] = ...,
+ maximum: Optional[_T] = ...,
+ inclusive: Tuple[bool, bool] = ...,
+ reverse: bool = ...,
+ ) -> Iterator[_T]: ...
+ def index(
+ self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
+ ) -> int: ...
+ def _reset(self, load: int) -> None: ...
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 829061c870..5f0f34119b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -198,6 +198,12 @@ class EventContentFields:
# cf https://github.com/matrix-org/matrix-doc/pull/1772
ROOM_TYPE = "type"
+ # The creator of the room, as used in `m.room.create` events.
+ ROOM_CREATOR = "creator"
+
+ # Used in m.room.guest_access events.
+ GUEST_ACCESS = "guest_access"
+
# Used on normal messages to indicate they were historically imported after the fact
MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
# For "insertion" events to indicate what the next chunk ID should be in
@@ -232,5 +238,11 @@ class HistoryVisibility:
WORLD_READABLE = "world_readable"
+class GuestAccess:
+ CAN_JOIN = "can_join"
+ # anything that is not "can_join" is considered "forbidden", but for completeness:
+ FORBIDDEN = "forbidden"
+
+
class ReadReceiptEventFields:
MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 39e28aff9f..89bda00090 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import atexit
import gc
import logging
import os
@@ -36,6 +37,7 @@ from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import ManholeConfig
from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
@@ -229,7 +231,12 @@ def listen_metrics(bind_addresses, port):
start_http_server(port, addr=host, registry=RegistryProxy)
-def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict):
+def listen_manhole(
+ bind_addresses: Iterable[str],
+ port: int,
+ manhole_settings: ManholeConfig,
+ manhole_globals: dict,
+):
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
# suppress the warning for now.
@@ -244,7 +251,7 @@ def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: di
listen_tcp(
bind_addresses,
port,
- manhole(username="matrix", password="rabbithole", globals=manhole_globals),
+ manhole(settings=manhole_settings, globals=manhole_globals),
)
@@ -403,6 +410,12 @@ async def start(hs: "HomeServer"):
gc.collect()
gc.freeze()
+ # Speed up shutdowns by freezing all allocated objects. This moves everything
+ # into the permanent generation and excludes them from the final GC.
+ # Unfortunately only works on Python 3.7
+ if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+ atexit.register(gc.freeze)
+
def setup_sentry(hs):
"""Enable sentry integration, if enabled in configuration
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 9b71dd75e6..2eb8d5a79c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -395,7 +395,10 @@ class GenericWorkerServer(HomeServer):
self._listen_http(listener)
elif listener.type == "manhole":
_base.listen_manhole(
- listener.bind_addresses, listener.port, manhole_globals={"hs": self}
+ listener.bind_addresses,
+ listener.port,
+ manhole_settings=self.config.server.manhole_settings,
+ manhole_globals={"hs": self},
)
elif listener.type == "metrics":
if not self.config.enable_metrics:
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7dae163c1a..708db86f5d 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -291,7 +291,10 @@ class SynapseHomeServer(HomeServer):
)
elif listener.type == "manhole":
_base.listen_manhole(
- listener.bind_addresses, listener.port, manhole_globals={"hs": self}
+ listener.bind_addresses,
+ listener.port,
+ manhole_settings=self.config.server.manhole_settings,
+ manhole_globals={"hs": self},
)
elif listener.type == "replication":
services = listen_tcp(
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 1f42a51857..442f1b9ac0 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -30,6 +30,7 @@ from .key import KeyConfig
from .logger import LoggingConfig
from .metrics import MetricsConfig
from .modules import ModulesConfig
+from .oembed import OembedConfig
from .oidc import OIDCConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
@@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig):
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
+ OembedConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py
new file mode 100644
index 0000000000..09267b5eef
--- /dev/null
+++ b/synapse/config/oembed.py
@@ -0,0 +1,180 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import re
+from typing import Any, Dict, Iterable, List, Pattern
+from urllib import parse as urlparse
+
+import attr
+import pkg_resources
+
+from synapse.types import JsonDict
+
+from ._base import Config, ConfigError
+from ._util import validate_config
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class OEmbedEndpointConfig:
+ # The API endpoint to fetch.
+ api_endpoint: str
+ # The patterns to match.
+ url_patterns: List[Pattern]
+
+
+class OembedConfig(Config):
+ """oEmbed Configuration"""
+
+ section = "oembed"
+
+ def read_config(self, config, **kwargs):
+ oembed_config: Dict[str, Any] = config.get("oembed") or {}
+
+ # A list of patterns which will be used.
+ self.oembed_patterns: List[OEmbedEndpointConfig] = list(
+ self._parse_and_validate_providers(oembed_config)
+ )
+
+ def _parse_and_validate_providers(
+ self, oembed_config: dict
+ ) -> Iterable[OEmbedEndpointConfig]:
+ """Extract and parse the oEmbed providers from the given JSON file.
+
+ Returns a generator which yields the OidcProviderConfig objects
+ """
+ # Whether to use the packaged providers.json file.
+ if not oembed_config.get("disable_default_providers") or False:
+ providers = json.load(
+ pkg_resources.resource_stream("synapse", "res/providers.json")
+ )
+ yield from self._parse_and_validate_provider(
+ providers, config_path=("oembed",)
+ )
+
+ # The JSON files which includes additional provider information.
+ for i, file in enumerate(oembed_config.get("additional_providers") or []):
+ # TODO Error checking.
+ with open(file) as f:
+ providers = json.load(f)
+
+ yield from self._parse_and_validate_provider(
+ providers,
+ config_path=(
+ "oembed",
+ "additional_providers",
+ f"<item {i}>",
+ ),
+ )
+
+ def _parse_and_validate_provider(
+ self, providers: List[JsonDict], config_path: Iterable[str]
+ ) -> Iterable[OEmbedEndpointConfig]:
+ # Ensure it is the proper form.
+ validate_config(
+ _OEMBED_PROVIDER_SCHEMA,
+ providers,
+ config_path=config_path,
+ )
+
+ # Parse it and yield each result.
+ for provider in providers:
+ # Each provider might have multiple API endpoints, each which
+ # might have multiple patterns to match.
+ for endpoint in provider["endpoints"]:
+ api_endpoint = endpoint["url"]
+ patterns = [
+ self._glob_to_pattern(glob, config_path)
+ for glob in endpoint["schemes"]
+ ]
+ yield OEmbedEndpointConfig(api_endpoint, patterns)
+
+ def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern:
+ """
+ Convert the glob into a sane regular expression to match against. The
+ rules followed will be slightly different for the domain portion vs.
+ the rest.
+
+ 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+ 2. The domain can have globs, but we limit it to characters that can
+ reasonably be a domain part.
+ TODO: This does not attempt to handle Unicode domain names.
+ TODO: The domain should not allow wildcard TLDs.
+ 3. Other parts allow a glob to be any one, or more, characters.
+ """
+ results = urlparse.urlparse(glob)
+
+ # Ensure the scheme does not have wildcards (and is a sane scheme).
+ if results.scheme not in {"http", "https"}:
+ raise ConfigError(f"Insecure oEmbed scheme: {results.scheme}", config_path)
+
+ pattern = urlparse.urlunparse(
+ [
+ results.scheme,
+ re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+ ]
+ + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+ )
+ return re.compile(pattern)
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ # oEmbed allows for easier embedding content from a website. It can be
+ # used for generating URLs previews of services which support it.
+ #
+ oembed:
+ # A default list of oEmbed providers is included with Synapse.
+ #
+ # Uncomment the following to disable using these default oEmbed URLs.
+ # Defaults to 'false'.
+ #
+ #disable_default_providers: true
+
+ # Additional files with oEmbed configuration (each should be in the
+ # form of providers.json).
+ #
+ # By default, this list is empty (so only the default providers.json
+ # is used).
+ #
+ #additional_providers:
+ # - oembed/my_providers.json
+ """
+
+
+_OEMBED_PROVIDER_SCHEMA = {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "provider_name": {"type": "string"},
+ "provider_url": {"type": "string"},
+ "endpoints": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "schemes": {
+ "type": "array",
+ "items": {"type": "string"},
+ },
+ "url": {"type": "string"},
+ "formats": {"type": "array", "items": {"type": "string"}},
+ "discovery": {"type": "boolean"},
+ },
+ "required": ["schemes", "url"],
+ },
+ },
+ },
+ "required": ["provider_name", "provider_url", "endpoints"],
+ },
+}
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d2c900f50c..7b9109a592 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -25,11 +25,14 @@ import attr
import yaml
from netaddr import AddrFormatError, IPNetwork, IPSet
+from twisted.conch.ssh.keys import Key
+
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_server_name
from ._base import Config, ConfigError
+from ._util import validate_config
logger = logging.Logger(__name__)
@@ -216,6 +219,16 @@ class ListenerConfig:
http_options = attr.ib(type=Optional[HttpListenerConfig], default=None)
+@attr.s(frozen=True)
+class ManholeConfig:
+ """Object describing the configuration of the manhole"""
+
+ username = attr.ib(type=str, validator=attr.validators.instance_of(str))
+ password = attr.ib(type=str, validator=attr.validators.instance_of(str))
+ priv_key = attr.ib(type=Optional[Key])
+ pub_key = attr.ib(type=Optional[Key])
+
+
class ServerConfig(Config):
section = "server"
@@ -649,6 +662,41 @@ class ServerConfig(Config):
)
)
+ manhole_settings = config.get("manhole_settings") or {}
+ validate_config(
+ _MANHOLE_SETTINGS_SCHEMA, manhole_settings, ("manhole_settings",)
+ )
+
+ manhole_username = manhole_settings.get("username", "matrix")
+ manhole_password = manhole_settings.get("password", "rabbithole")
+ manhole_priv_key_path = manhole_settings.get("ssh_priv_key_path")
+ manhole_pub_key_path = manhole_settings.get("ssh_pub_key_path")
+
+ manhole_priv_key = None
+ if manhole_priv_key_path is not None:
+ try:
+ manhole_priv_key = Key.fromFile(manhole_priv_key_path)
+ except Exception as e:
+ raise ConfigError(
+ f"Failed to read manhole private key file {manhole_priv_key_path}"
+ ) from e
+
+ manhole_pub_key = None
+ if manhole_pub_key_path is not None:
+ try:
+ manhole_pub_key = Key.fromFile(manhole_pub_key_path)
+ except Exception as e:
+ raise ConfigError(
+ f"Failed to read manhole public key file {manhole_pub_key_path}"
+ ) from e
+
+ self.manhole_settings = ManholeConfig(
+ username=manhole_username,
+ password=manhole_password,
+ priv_key=manhole_priv_key,
+ pub_key=manhole_pub_key,
+ )
+
metrics_port = config.get("metrics_port")
if metrics_port:
logger.warning(METRICS_PORT_WARNING)
@@ -715,7 +763,7 @@ class ServerConfig(Config):
if not isinstance(templates_config, dict):
raise ConfigError("The 'templates' section must be a dictionary")
- self.custom_template_directory = templates_config.get(
+ self.custom_template_directory: Optional[str] = templates_config.get(
"custom_template_directory"
)
if self.custom_template_directory is not None and not isinstance(
@@ -727,7 +775,13 @@ class ServerConfig(Config):
return any(listener.tls for listener in self.listeners)
def generate_config_section(
- self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
+ self,
+ server_name,
+ data_dir_path,
+ open_private_ports,
+ listeners,
+ config_dir_path,
+ **kwargs,
):
ip_range_blacklist = "\n".join(
" # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
@@ -1068,6 +1122,24 @@ class ServerConfig(Config):
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
+ # Connection settings for the manhole
+ #
+ manhole_settings:
+ # The username for the manhole. This defaults to 'matrix'.
+ #
+ #username: manhole
+
+ # The password for the manhole. This defaults to 'rabbithole'.
+ #
+ #password: mypassword
+
+ # The private and public SSH key pair used to encrypt the manhole traffic.
+ # If these are left unset, then hardcoded and non-secret keys are used,
+ # which could allow traffic to be intercepted if sent over a public network.
+ #
+ #ssh_priv_key_path: %(config_dir_path)s/id_rsa
+ #ssh_pub_key_path: %(config_dir_path)s/id_rsa.pub
+
# Forward extremities can build up in a room due to networking delays between
# homeservers. Once this happens in a large room, calculation of the state of
# that room can become quite expensive. To mitigate this, once the number of
@@ -1436,3 +1508,14 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None:
if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
+
+
+_MANHOLE_SETTINGS_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "username": {"type": "string"},
+ "password": {"type": "string"},
+ "ssh_priv_key_path": {"type": "string"},
+ "ssh_pub_key_path": {"type": "string"},
+ },
+}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c3a0c10499..b63a1afe93 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -216,21 +216,18 @@ def check(
def _check_size_limits(event: EventBase) -> None:
- def too_big(field):
- raise EventSizeError("%s too large" % (field,))
-
if len(event.user_id) > 255:
- too_big("user_id")
+ raise EventSizeError("'user_id' too large")
if len(event.room_id) > 255:
- too_big("room_id")
+ raise EventSizeError("'room_id' too large")
if event.is_state() and len(event.state_key) > 255:
- too_big("state_key")
+ raise EventSizeError("'state_key' too large")
if len(event.type) > 255:
- too_big("type")
+ raise EventSizeError("'type' too large")
if len(event.event_id) > 255:
- too_big("event_id")
+ raise EventSizeError("'event_id' too large")
if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
- too_big("event")
+ raise EventSizeError("event too large")
def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6a05a65305..955cfa2207 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,10 +15,7 @@
import logging
from typing import TYPE_CHECKING, Optional
-import synapse.types
-from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
-from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -115,68 +112,3 @@ class BaseHandler:
burst_count=burst_count,
update=update,
)
-
- async def maybe_kick_guest_users(self, event, context=None):
- # Technically this function invalidates current_state by changing it.
- # Hopefully this isn't that important to the caller.
- if event.type == EventTypes.GuestAccess:
- guest_access = event.content.get("guest_access", "forbidden")
- if guest_access != "can_join":
- if context:
- current_state_ids = await context.get_current_state_ids()
- current_state_dict = await self.store.get_events(
- list(current_state_ids.values())
- )
- current_state = list(current_state_dict.values())
- else:
- current_state_map = await self.state_handler.get_current_state(
- event.room_id
- )
- current_state = list(current_state_map.values())
-
- logger.info("maybe_kick_guest_users %r", current_state)
- await self.kick_guest_users(current_state)
-
- async def kick_guest_users(self, current_state):
- for member_event in current_state:
- try:
- if member_event.type != EventTypes.Member:
- continue
-
- target_user = UserID.from_string(member_event.state_key)
- if not self.hs.is_mine(target_user):
- continue
-
- if member_event.content["membership"] not in {
- Membership.JOIN,
- Membership.INVITE,
- }:
- continue
-
- if (
- "kind" not in member_event.content
- or member_event.content["kind"] != "guest"
- ):
- continue
-
- # We make the user choose to leave, rather than have the
- # event-sender kick them. This is partially because we don't
- # need to worry about power levels, and partially because guest
- # users are a concept which doesn't hugely work over federation,
- # and having homeservers have their own users leave keeps more
- # of that decision-making and control local to the guest-having
- # homeserver.
- requester = synapse.types.create_requester(
- target_user, is_guest=True, authenticated_entity=self.server_name
- )
- handler = self.hs.get_room_member_handler()
- await handler.update_membership(
- requester,
- target_user,
- member_event.room_id,
- "leave",
- ratelimit=False,
- require_consent=False,
- )
- except Exception as e:
- logger.exception("Error kicking guest user: %s" % (e,))
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index daf1d3bfb3..77df9185f6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -507,6 +507,7 @@ class FederationHandler(BaseHandler):
await self.store.upsert_room_on_join(
room_id=room_id,
room_version=room_version_obj,
+ auth_events=auth_chain,
)
max_stream_id = await self._persist_auth_tree(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9f055f00cf..afeb2892dc 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -36,6 +36,7 @@ from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
+ GuestAccess,
Membership,
RejectedReason,
RoomEncryptionAlgorithms,
@@ -53,7 +54,6 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers._base import BaseHandler
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -116,7 +116,7 @@ class _NewEventInfo:
claimed_auth_event_map: StateMap[EventBase]
-class FederationEventHandler(BaseHandler):
+class FederationEventHandler:
"""Handles events that originated from federation.
Responsible for handing incoming events and passing them on to the rest
@@ -124,8 +124,6 @@ class FederationEventHandler(BaseHandler):
"""
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -136,11 +134,15 @@ class FederationEventHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self.action_generator = hs.get_action_generator()
self._state_resolution_handler = hs.get_state_resolution_handler()
+ # avoid a circular dependency by deferring execution here
+ self._get_room_member_handler = hs.get_room_member_handler
self.federation_client = hs.get_federation_client()
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
+ self._server_name = hs.hostname
self._instance_name = hs.get_instance_name()
self.config = hs.config
@@ -221,7 +223,7 @@ class FederationEventHandler(BaseHandler):
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = await self._event_auth_handler.check_host_in_room(
- room_id, self.server_name
+ room_id, self._server_name
)
if not is_in_room:
logger.info(
@@ -434,7 +436,7 @@ class FederationEventHandler(BaseHandler):
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
"""
- if dest == self.server_name:
+ if dest == self._server_name:
raise SynapseError(400, "Can't backfill from self.")
events = await self.federation_client.backfill(
@@ -1023,9 +1025,15 @@ class FederationEventHandler(BaseHandler):
return
# Skip processing a marker event if the room version doesn't
- # support it.
+ # support it or the event is not from the room creator.
room_version = await self.store.get_room_version(marker_event.room_id)
- if not room_version.msc2716_historical:
+ create_event = await self.store.get_create_event_for_room(marker_event.room_id)
+ room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+ if (
+ not room_version.msc2716_historical
+ or not self.config.experimental.msc2716_enabled
+ or marker_event.sender != room_creator
+ ):
return
logger.debug("_handle_marker_event: received %s", marker_event)
@@ -1321,9 +1329,7 @@ class FederationEventHandler(BaseHandler):
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled, origin=origin)
-
- if event.type == EventTypes.GuestAccess and not context.rejected:
- await self.maybe_kick_guest_users(event)
+ await self._maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
@@ -1334,6 +1340,18 @@ class FederationEventHandler(BaseHandler):
return context
+ async def _maybe_kick_guest_users(self, event: EventBase) -> None:
+ if event.type != EventTypes.GuestAccess:
+ return
+
+ guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+ if guest_access == GuestAccess.CAN_JOIN:
+ return
+
+ current_state_map = await self.state_handler.get_current_state(event.room_id)
+ current_state = list(current_state_map.values())
+ await self._get_room_member_handler().kick_guest_users(current_state)
+
async def _check_for_soft_fail(
self,
event: EventBase,
@@ -1787,7 +1805,7 @@ class FederationEventHandler(BaseHandler):
event_pos = PersistedEventPosition(
self._instance_name, event.internal_metadata.stream_ordering
)
- self.notifier.on_new_room_event(
+ self._notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 101a29c6d3..bf0fef1510 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -27,6 +27,7 @@ from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
+ GuestAccess,
Membership,
RelationTypes,
UserTypes,
@@ -426,7 +427,7 @@ class EventCreationHandler:
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
- # This is only used to get at ratelimit function, and maybe_kick_guest_users
+ # This is only used to get at ratelimit function
self.base_handler = BaseHandler(hs)
# We arbitrarily limit concurrent event creation for a room to 5.
@@ -1306,7 +1307,7 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
- await self.base_handler.maybe_kick_guest_users(event, context)
+ await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases.
@@ -1393,6 +1394,9 @@ class EventCreationHandler:
allow_none=True,
)
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
# we can make some additional checks now if we have the original event.
if original_event:
if original_event.type == EventTypes.Create:
@@ -1404,6 +1408,28 @@ class EventCreationHandler:
if original_event.type == EventTypes.ServerACL:
raise AuthError(403, "Redacting server ACL events is not permitted")
+ # Add a little safety stop-gap to prevent people from trying to
+ # redact MSC2716 related events when they're in a room version
+ # which does not support it yet. We allow people to use MSC2716
+ # events in existing room versions but only from the room
+ # creator since it does not require any changes to the auth
+ # rules and in effect, the redaction algorithm . In the
+ # supported room version, we add the `historical` power level to
+ # auth the MSC2716 related events and adjust the redaction
+ # algorthim to keep the `historical` field around (redacting an
+ # event should only strip fields which don't affect the
+ # structural protocol level).
+ is_msc2716_event = (
+ original_event.type == EventTypes.MSC2716_INSERTION
+ or original_event.type == EventTypes.MSC2716_CHUNK
+ or original_event.type == EventTypes.MSC2716_MARKER
+ )
+ if not room_version_obj.msc2716_historical and is_msc2716_event:
+ raise AuthError(
+ 403,
+ "Redacting MSC2716 events is not supported in this room version",
+ )
+
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
@@ -1411,9 +1437,6 @@ class EventCreationHandler:
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
- room_version = await self.store.get_room_version_id(event.room_id)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
if event_auth.check_redaction(
room_version_obj, event, auth_events=auth_events
):
@@ -1471,6 +1494,28 @@ class EventCreationHandler:
return event
+ async def _maybe_kick_guest_users(
+ self, event: EventBase, context: EventContext
+ ) -> None:
+ if event.type != EventTypes.GuestAccess:
+ return
+
+ guest_access = event.content.get(EventContentFields.GUEST_ACCESS)
+ if guest_access == GuestAccess.CAN_JOIN:
+ return
+
+ current_state_ids = await context.get_current_state_ids()
+
+ # since this is a client-generated event, it cannot be an outlier and we must
+ # therefore have the state ids.
+ assert current_state_ids is not None
+ current_state_dict = await self.store.get_events(
+ list(current_state_ids.values())
+ )
+ current_state = list(current_state_dict.values())
+ logger.info("maybe_kick_guest_users %r", current_state)
+ await self.hs.get_room_member_handler().kick_guest_users(current_state)
+
async def _bump_active_time(self, user: UserID) -> None:
try:
presence = self.hs.get_presence_handler()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b33fe09f77..0235fd09b4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,7 +25,9 @@ from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
+ EventContentFields,
EventTypes,
+ GuestAccess,
HistoryVisibility,
JoinRules,
Membership,
@@ -909,7 +911,12 @@ class RoomCreationHandler(BaseHandler):
)
return last_stream_id
- config = self._presets_dict[preset_config]
+ try:
+ config = self._presets_dict[preset_config]
+ except KeyError:
+ raise SynapseError(
+ 400, f"'{preset_config}' is not a valid preset", errcode=Codes.BAD_JSON
+ )
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -988,7 +995,8 @@ class RoomCreationHandler(BaseHandler):
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
last_sent_stream_id = await send(
- etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
+ etype=EventTypes.GuestAccess,
+ content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
)
for (etype, state_key), content in initial_state.items():
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 6d433fad41..92bb75c848 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -19,7 +19,13 @@ from typing import TYPE_CHECKING, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
-from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ GuestAccess,
+ HistoryVisibility,
+ JoinRules,
+)
from synapse.api.errors import (
Codes,
HttpResponseException,
@@ -336,8 +342,8 @@ class RoomListHandler(BaseHandler):
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
- guest = guest_event.content.get("guest_access", None)
- result["guest_can_join"] = guest == "can_join"
+ guest = guest_event.content.get(EventContentFields.GUEST_ACCESS)
+ result["guest_can_join"] = guest == GuestAccess.CAN_JOIN
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 401b84aad1..4390201641 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -23,6 +23,7 @@ from synapse.api.constants import (
AccountDataTypes,
EventContentFields,
EventTypes,
+ GuestAccess,
Membership,
)
from synapse.api.errors import (
@@ -44,6 +45,7 @@ from synapse.types import (
RoomID,
StateMap,
UserID,
+ create_requester,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer
@@ -70,6 +72,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
+ self._server_name = hs.hostname
self.federation_handler = hs.get_federation_handler()
self.directory_handler = hs.get_directory_handler()
@@ -115,9 +118,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
)
- # This is only used to get at ratelimit function, and
- # maybe_kick_guest_users. It's fine there are multiple of these as
- # it doesn't store state.
+ # This is only used to get at the ratelimit function. It's fine there are
+ # multiple of these as it doesn't store state.
self.base_handler = BaseHandler(hs)
@abc.abstractmethod
@@ -1095,10 +1097,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return bool(
guest_access
and guest_access.content
- and "guest_access" in guest_access.content
- and guest_access.content["guest_access"] == "can_join"
+ and guest_access.content.get(EventContentFields.GUEST_ACCESS)
+ == GuestAccess.CAN_JOIN
)
+ async def kick_guest_users(self, current_state: Iterable[EventBase]) -> None:
+ """Kick any local guest users from the room.
+
+ This is called when the room state changes from guests allowed to not-allowed.
+
+ Params:
+ current_state: the current state of the room. We will iterate this to look
+ for guest users to kick.
+ """
+ for member_event in current_state:
+ try:
+ if member_event.type != EventTypes.Member:
+ continue
+
+ if not self.hs.is_mine_id(member_event.state_key):
+ continue
+
+ if member_event.content["membership"] not in {
+ Membership.JOIN,
+ Membership.INVITE,
+ }:
+ continue
+
+ if (
+ "kind" not in member_event.content
+ or member_event.content["kind"] != "guest"
+ ):
+ continue
+
+ # We make the user choose to leave, rather than have the
+ # event-sender kick them. This is partially because we don't
+ # need to worry about power levels, and partially because guest
+ # users are a concept which doesn't hugely work over federation,
+ # and having homeservers have their own users leave keeps more
+ # of that decision-making and control local to the guest-having
+ # homeserver.
+ target_user = UserID.from_string(member_event.state_key)
+ requester = create_requester(
+ target_user, is_guest=True, authenticated_entity=self._server_name
+ )
+ handler = self.hs.get_room_member_handler()
+ await handler.update_membership(
+ requester,
+ target_user,
+ member_event.room_id,
+ "leave",
+ ratelimit=False,
+ require_consent=False,
+ )
+ except Exception as e:
+ logger.exception("Error kicking guest user: %s" % (e,))
+
async def lookup_room_alias(
self, room_alias: RoomAlias
) -> Tuple[RoomID, List[str]]:
@@ -1352,7 +1406,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room")
- self._server_name = hs.hostname
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 906985c754..4bc9c73e6e 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,7 +28,14 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
-from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ NotFoundError,
+ StoreError,
+ SynapseError,
+ UnsupportedRoomVersionError,
+)
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
@@ -814,7 +821,12 @@ class RoomSummaryHandler:
logger.info("room %s is unknown, omitting from summary", room_id)
return False
- room_version = await self._store.get_room_version(room_id)
+ try:
+ room_version = await self._store.get_room_version(room_id)
+ except UnsupportedRoomVersionError:
+ # If a room with an unsupported room version is encountered, ignore
+ # it to avoid breaking the entire summary response.
+ return False
# Include the room if it has join rules of public or knock.
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
@@ -1139,25 +1151,26 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool:
_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]")
-def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]:
+def _child_events_comparison_key(
+ child: EventBase,
+) -> Tuple[bool, Optional[str], int, str]:
"""
Generate a value for comparing two child events for ordering.
- The rules for ordering are supposed to be:
+ The rules for ordering are:
1. The 'order' key, if it is valid.
- 2. The 'origin_server_ts' of the 'm.room.create' event.
+ 2. The 'origin_server_ts' of the 'm.space.child' event.
3. The 'room_id'.
- But we skip step 2 since we may not have any state from the room.
-
Args:
child: The event for generating a comparison key.
Returns:
The comparison key as a tuple of:
False if the ordering is valid.
- The ordering field.
+ The 'order' field or None if it is not given or invalid.
+ The 'origin_server_ts' field.
The room ID.
"""
order = child.content.get("order")
@@ -1168,4 +1181,4 @@ def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str],
order = None
# Items without an order come last.
- return (order is None, order, child.room_id)
+ return (order is None, order, child.origin_server_ts, child.room_id)
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 3fd89af2a4..3a4c41c9ff 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
from typing_extensions import Counter as CounterType
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict
@@ -273,7 +273,9 @@ class StatsHandler:
elif typ == EventTypes.CanonicalAlias:
room_state["canonical_alias"] = event_content.get("alias")
elif typ == EventTypes.GuestAccess:
- room_state["guest_access"] = event_content.get("guest_access")
+ room_state["guest_access"] = event_content.get(
+ EventContentFields.GUEST_ACCESS
+ )
for room_id, state in room_to_state_updates.items():
logger.debug("Updating room_stats_state for %s: %s", room_id, state)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 86c3c7f0df..e017b28cd2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -505,10 +505,13 @@ class SyncHandler:
else:
limited = False
+ log_kv({"limited": limited})
+
if potential_recents:
recents = sync_config.filter_collection.filter_room_timeline(
potential_recents
)
+ log_kv({"recents_after_sync_filtering": len(recents)})
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
@@ -526,6 +529,7 @@ class SyncHandler:
recents,
always_include_ids=current_state_ids,
)
+ log_kv({"recents_after_visibility_filtering": len(recents)})
else:
recents = []
@@ -566,10 +570,15 @@ class SyncHandler:
events, end_key = await self.store.get_recent_events_for_room(
room_id, limit=load_limit + 1, end_token=end_key
)
+
+ log_kv({"loaded_recents": len(events)})
+
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
+ log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
+
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
@@ -586,6 +595,9 @@ class SyncHandler:
loaded_recents,
always_include_ids=current_state_ids,
)
+
+ log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)})
+
loaded_recents.extend(recents)
recents = loaded_recents
@@ -1116,6 +1128,8 @@ class SyncHandler:
logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder)
+ num_events = 0
+
# debug for https://github.com/matrix-org/synapse/issues/4422
for joined_room in sync_result_builder.joined:
room_id = joined_room.room_id
@@ -1123,6 +1137,14 @@ class SyncHandler:
issue4422_logger.debug(
"Sync result for newly joined room %s: %r", room_id, joined_room
)
+ num_events += len(joined_room.timeline.events)
+
+ log_kv(
+ {
+ "joined_rooms_in_result": len(sync_result_builder.joined),
+ "events_in_result": num_events,
+ }
+ )
logger.debug("Sync response calculation complete")
return SyncResult(
@@ -1467,6 +1489,7 @@ class SyncHandler:
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
have_changed = await self._have_rooms_changed(sync_result_builder)
+ log_kv({"rooms_have_changed": have_changed})
if not have_changed:
tags_by_room = await self.store.get_updated_tags(
user_id, since_token.account_data_key
@@ -1501,25 +1524,30 @@ class SyncHandler:
tags_by_room = await self.store.get_tags_for_user(user_id)
+ log_kv({"rooms_changed": len(room_changes.room_entries)})
+
room_entries = room_changes.room_entries
invited = room_changes.invited
knocked = room_changes.knocked
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
- async def handle_room_entries(room_entry):
- logger.debug("Generating room entry for %s", room_entry.room_id)
- res = await self._generate_room_entry(
- sync_result_builder,
- ignored_users,
- room_entry,
- ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
- tags=tags_by_room.get(room_entry.room_id),
- account_data=account_data_by_room.get(room_entry.room_id, {}),
- always_include=sync_result_builder.full_state,
- )
- logger.debug("Generated room entry for %s", room_entry.room_id)
- return res
+ async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
+ with start_active_span("generate_room_entry"):
+ set_tag("room_id", room_entry.room_id)
+ log_kv({"events": len(room_entry.events or [])})
+ logger.debug("Generating room entry for %s", room_entry.room_id)
+ res = await self._generate_room_entry(
+ sync_result_builder,
+ ignored_users,
+ room_entry,
+ ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
+ tags=tags_by_room.get(room_entry.room_id),
+ account_data=account_data_by_room.get(room_entry.room_id, {}),
+ always_include=sync_result_builder.full_state,
+ )
+ logger.debug("Generated room entry for %s", room_entry.room_id)
+ return res
await concurrently_execute(handle_room_entries, room_entries, 10)
@@ -1932,6 +1960,12 @@ class SyncHandler:
room_id = room_builder.room_id
since_token = room_builder.since_token
upto_token = room_builder.upto_token
+ log_kv(
+ {
+ "since_token": since_token,
+ "upto_token": upto_token,
+ }
+ )
batch = await self._load_filtered_recents(
room_id,
@@ -1941,6 +1975,13 @@ class SyncHandler:
potential_recents=events,
newly_joined_room=newly_joined,
)
+ log_kv(
+ {
+ "batch_events": len(batch.events),
+ "prev_batch": batch.prev_batch,
+ "batch_limited": batch.limited,
+ }
+ )
# Note: `batch` can be both empty and limited here in the case where
# `_load_filtered_recents` can't find any events the user should see
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a12fa30bfd..91ba93372c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -572,6 +572,25 @@ def parse_string_from_args(
return strings[0]
+@overload
+def parse_json_value_from_request(request: Request) -> JsonDict:
+ ...
+
+
+@overload
+def parse_json_value_from_request(
+ request: Request, allow_empty_body: Literal[False]
+) -> JsonDict:
+ ...
+
+
+@overload
+def parse_json_value_from_request(
+ request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
+ ...
+
+
def parse_json_value_from_request(
request: Request, allow_empty_body: bool = False
) -> Optional[JsonDict]:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 941fb238b7..b0834720ad 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -258,7 +258,7 @@ class Mailer:
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
- rooms = []
+ rooms: List[Dict[str, Any]] = []
for r in rooms_in_order:
roomvars = await self._get_room_vars(
@@ -362,6 +362,7 @@ class Mailer:
"notifs": [],
"invite": is_invite,
"link": self._make_room_link(room_id),
+ "avatar_url": await self._get_room_avatar(room_state_ids),
}
if not is_invite:
@@ -393,6 +394,27 @@ class Mailer:
return room_vars
+ async def _get_room_avatar(
+ self,
+ room_state_ids: StateMap[str],
+ ) -> Optional[str]:
+ """
+ Retrieve the avatar url for this room---if it exists.
+
+ Args:
+ room_state_ids: The event IDs of the current room state.
+
+ Returns:
+ room's avatar url if it's present and a string; otherwise None.
+ """
+ event_id = room_state_ids.get((EventTypes.RoomAvatar, ""))
+ if event_id:
+ ev = await self.store.get_event(event_id)
+ url = ev.content.get("url")
+ if isinstance(url, str):
+ return url
+ return None
+
async def _get_notif_vars(
self,
notif: Dict[str, Any],
diff --git a/synapse/res/providers.json b/synapse/res/providers.json
new file mode 100644
index 0000000000..f1838f9559
--- /dev/null
+++ b/synapse/res/providers.json
@@ -0,0 +1,17 @@
+[
+ {
+ "provider_name": "Twitter",
+ "provider_url": "http://www.twitter.com/",
+ "endpoints": [
+ {
+ "schemes": [
+ "https://twitter.com/*/status/*",
+ "https://*.twitter.com/*/status/*",
+ "https://twitter.com/*/moments/*",
+ "https://*.twitter.com/*/moments/*"
+ ],
+ "url": "https://publish.twitter.com/oembed"
+ }
+ ]
+ }
+]
\ No newline at end of file
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 42201afc86..f5a38c2670 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError
@@ -101,7 +101,9 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
+ def on_PUT(
+ self, request: SynapseRequest, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 0443f4571c..a0971ce994 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -16,7 +16,7 @@
"""
import logging
import re
-from typing import Iterable, Pattern
+from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
)
-def interactive_auth_handler(orig):
+C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])
+
+
+def interactive_auth_handler(orig: C) -> C:
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns an Awaitable (errcode, body) response
@@ -91,10 +94,10 @@ def interactive_auth_handler(orig):
await self.auth_handler.check_auth
"""
- async def wrapped(*args, **kwargs):
+ async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result
- return wrapped
+ return cast(C, wrapped)
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index fb5ad2906e..aefaaa8ae8 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -16,9 +16,11 @@
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
from urllib.parse import urlparse
+from twisted.web.server import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@@ -28,15 +30,17 @@ from synapse.api.errors import (
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed, validate_email
@@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=self.config.email_password_reset_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet):
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these
@@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
+ requester = None
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
try:
@@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
- password_hash = await self.auth_handler.hash(new_password)
+ new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
- password_hash,
+ new_password_hash,
)
raise
user_id = requester.user.to_string()
else:
- requester = None
try:
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
@@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet):
# If a password is available now, hash the provided password and
# store it for later.
if new_password:
- password_hash = await self.auth_handler.hash(new_password)
+ new_password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
- password_hash,
+ new_password_hash,
)
raise
@@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet):
# If we have a password in this request, prefer it. Otherwise, use the
# password hash from an earlier request.
if new_password:
- password_hash = await self.auth_handler.hash(new_password)
+ password_hash: Optional[str] = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
@@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.config = hs.config
@@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
template_text=self.config.email_add_threepid_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
@@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
"/add_threepid/email/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
@@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config.email_add_threepid_template_failure_html
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> None:
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
if not self.config.account_threepid_delegate_msisdn:
raise SynapseError(
400,
@@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet):
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet):
class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
@@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
@@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
-def assert_valid_next_link(hs: "HomeServer", next_link: str):
+def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
"""
Raises a SynapseError if a given next_link value is invalid
@@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str):
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
response = {"user_id": requester.user.to_string()}
@@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet):
return 200, response
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 7517e9304e..d1badbdf3b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -13,12 +13,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, account_data_type):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
return 200, {}
- async def on_GET(self, request, user_id, account_data_type):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, account_data_type: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
"/account_data/(?P<account_data_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
- async def on_PUT(self, request, user_id, room_id, account_data_type):
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
return 200, {}
- async def on_GET(self, request, user_id, room_id, account_data_type):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ room_id: str,
+ account_data_type: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
return 200, event
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index c3667ff8aa..a7e9aa3e9b 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -15,7 +15,7 @@
import logging
from functools import wraps
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request
@@ -43,14 +43,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _validate_group_id(f):
+def _validate_group_id(
+ f: Callable[..., Awaitable[Tuple[int, JsonDict]]]
+) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]:
"""Wrapper to validate the form of the group ID.
Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
"""
@wraps(f)
- def wrapper(self, request: Request, group_id: str, *args, **kwargs):
+ def wrapper(
+ self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any
+ ) -> Awaitable[Tuple[int, JsonDict]]:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@@ -156,7 +160,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id: str,
category_id: Optional[str],
room_id: str,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -188,7 +192,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -451,7 +455,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -674,7 +678,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -706,7 +710,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: SynapseRequest, group_id, user_id
+ self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -738,7 +742,7 @@ class GroupAdminUsersKickServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: SynapseRequest, group_id, user_id
+ self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index 68fb08d0ba..0152a0c66a 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from twisted.web.server import Request
@@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id}
- def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
+ def on_PUT(
+ self, request: Request, room_identifier: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 702b351d18..fb3211bf3a 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,22 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
+
+import attr
+
from synapse.api.errors import (
NotFoundError,
StoreError,
SynapseError,
UnrecognizedRequestError,
)
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_value_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RuleSpec:
+ scope: str
+ template: str
+ rule_id: str
+ attr: Optional[str]
class PushRuleRestServlet(RestServlet):
@@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet):
"Unrecognised request: You probably wanted a trailing slash"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet):
self._users_new_default_push_rules = hs.config.users_new_default_push_rules
- async def on_PUT(self, request, path):
+ async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
- if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
+ if "/" in spec.rule_id or "\\" in spec.rule_id:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
- if "attr" in spec:
+ if spec.attr:
await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
return 200, {}
- if spec["rule_id"].startswith("."):
+ if spec.rule_id.startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try:
(conditions, actions) = _rule_tuple_from_request_object(
- spec["template"], spec["rule_id"], content
+ spec.template, spec.rule_id, content
)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
@@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
- async def on_DELETE(self, request, path):
+ async def on_DELETE(
+ self, request: SynapseRequest, path: str
+ ) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
@@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet):
else:
raise
- async def on_GET(self, request, path):
+ async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet):
rules = format_push_rules_for_user(requester.user, rules)
- path = path.split("/")[1:]
+ path_parts = path.split("/")[1:]
- if path == []:
+ if path_parts == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == "":
+ if path_parts[0] == "":
return 200, rules
- elif path[0] == "global":
- result = _filter_ruleset_with_path(rules["global"], path[1:])
+ elif path_parts[0] == "global":
+ result = _filter_ruleset_with_path(rules["global"], path_parts[1:])
return 200, result
else:
raise UnrecognizedRequestError()
- def notify_user(self, user_id):
+ def notify_user(self, user_id: str) -> None:
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
- async def set_rule_attr(self, user_id, spec, val):
- if spec["attr"] not in ("enabled", "actions"):
+ async def set_rule_attr(
+ self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+ ) -> None:
+ if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec["rule_id"]
+ rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
- if spec["attr"] == "enabled":
+ if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
@@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- return await self.store.set_push_rule_enabled(
+ await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
- elif spec["attr"] == "actions":
+ elif spec.attr == "actions":
+ if not isinstance(val, dict):
+ raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
+ if not isinstance(actions, list):
+ raise SynapseError(400, "Value for 'actions' must be dict")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec["rule_id"]
+ rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if user_id in self._users_new_default_push_rules:
@@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- return await self.store.set_push_rule_actions(
+ await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
-def _rule_spec_from_path(path):
+def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
Args:
- path (sequence[unicode]): the URL path components.
+ path: the URL path components.
Returns:
- dict: rule spec dict, containing scope/template/rule_id entries,
- and possibly attr.
+ rule spec, containing scope/template/rule_id entries, and possibly attr.
Raises:
UnrecognizedRequestError if the path components cannot be parsed.
@@ -237,17 +262,18 @@ def _rule_spec_from_path(path):
rule_id = path[0]
- spec = {"scope": scope, "template": template, "rule_id": rule_id}
-
path = path[1:]
+ attr = None
if len(path) > 0 and len(path[0]) > 0:
- spec["attr"] = path[0]
+ attr = path[0]
- return spec
+ return RuleSpec(scope, template, rule_id, attr)
-def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
+def _rule_tuple_from_request_object(
+ rule_template: str, rule_id: str, req_obj: JsonDict
+) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]:
if rule_template in ["override", "underride"]:
if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
@@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
return conditions, actions
-def _check_actions(actions):
+def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
@@ -290,7 +316,7 @@ def _check_actions(actions):
raise InvalidRuleException("Unrecognised action")
-def _filter_ruleset_with_path(ruleset, path):
+def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
@@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path):
if r["rule_id"] == rule_id:
the_rule = r
if the_rule is None:
- raise NotFoundError
+ raise NotFoundError()
path = path[1:]
if len(path) == 0:
@@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path):
raise UnrecognizedRequestError()
-def _priority_class_from_spec(spec):
- if spec["template"] not in PRIORITY_CLASS_MAP.keys():
- raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
- pc = PRIORITY_CLASS_MAP[spec["template"]]
+def _priority_class_from_spec(spec: RuleSpec) -> int:
+ if spec.template not in PRIORITY_CLASS_MAP.keys():
+ raise InvalidRuleException("Unknown template: %s" % (spec.template))
+ pc = PRIORITY_CLASS_MAP[spec.template]
return pc
-def _namespaced_rule_id_from_spec(spec):
- return _namespaced_rule_id(spec, spec["rule_id"])
+def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
+ return _namespaced_rule_id(spec, spec.rule_id)
-def _namespaced_rule_id(spec, rule_id):
- return "global/%s/%s" % (spec["template"], rule_id)
+def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
+ return "global/%s/%s" % (spec.template, rule_id)
class InvalidRuleException(Exception):
pass
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushRuleRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index d9ab836cd8..9770413c61 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,13 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
"/(?P<event_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- async def on_POST(self, request, room_id, receipt_type, event_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 7b5f49d635..8f3dd2a101 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -14,7 +14,9 @@
# limitations under the License.
import logging
import random
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.server import Request
import synapse
import synapse.api.auth
@@ -29,15 +31,13 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
-from synapse.config.captcha import CaptchaConfig
-from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
@@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text,
)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html
)
- async def on_GET(self, request, medium):
+ async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@@ -387,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
unstable=True,
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
@@ -402,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
)
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
if not self.hs.config.enable_registration:
@@ -419,11 +403,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -445,23 +425,21 @@ class RegisterRestServlet(RestServlet):
)
@interactive_auth_handler
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False)
- kind = b"user"
- if b"kind" in request.args:
- kind = request.args[b"kind"][0]
+ kind = parse_string(request, "kind", default="user")
- if kind == b"guest":
+ if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr)
return ret
- elif kind != b"user":
+ elif kind != "user":
raise UnrecognizedRequestError(
- "Do not understand membership kind: %s" % (kind.decode("utf8"),)
+ f"Do not understand membership kind: {kind}",
)
if self._msc2918_enabled:
@@ -748,8 +726,12 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
async def _do_appservice_registration(
- self, username, as_token, body, should_issue_refresh_token: bool = False
- ):
+ self,
+ username: str,
+ as_token: str,
+ body: JsonDict,
+ should_issue_refresh_token: bool = False,
+ ) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
)
@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
- ):
+ ) -> JsonDict:
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result
- async def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(
+ self, params: JsonDict, address: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows(
- # technically `config` has to provide *all* of these interfaces, not just one
- config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
- auth_handler: AuthHandler,
+ config: HomeServerConfig, auth_handler: AuthHandler
) -> List[List[str]]:
"""Get a suitable flows list for registration
@@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 0821cd285f..0b0711c03c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
"""
import logging
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
http_server.register_paths(
"POST",
client_patterns(self.PATTERN + "$", releases=()),
@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
self.__class__.__name__,
)
- def on_PUT(self, request, *args, **kwargs):
+ def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
- request, self.on_PUT_or_POST, request, *args, **kwargs
+ request,
+ self.on_PUT_or_POST,
+ request,
+ room_id,
+ parent_id,
+ relation_type,
+ event_type,
+ txn_id,
)
async def on_PUT_or_POST(
- self, request, room_id, parent_id, relation_type, event_type, txn_id=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ key: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return 200, return_value
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index 07ea39a8a3..d4a4adb50c 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -14,26 +14,35 @@
import logging
from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- async def on_POST(self, request, room_id, event_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c5c54564be..9b0c546505 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,11 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
+from twisted.web.server import Request
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -30,6 +32,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
- def on_PUT(self, request, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
info, _ = await self._room_creation_handler.create_room(
@@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
return 200, info
- def get_room_config(self, request):
+ def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__,
)
- def on_GET_no_state_key(self, request, room_id, event_type):
+ def on_GET_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_GET(request, room_id, event_type, "")
- def on_PUT_no_state_key(self, request, room_id, event_type):
+ def on_PUT_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "")
- async def on_GET(self, request, room_id, event_type, state_key):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
@@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ # Format must be event or content, per the parse_string call above.
+ raise RuntimeError(f"Unknown format: {format:r}.")
+
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if txn_id:
@@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- async def on_POST(self, request, room_id, event_type, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- event_dict = {
+ event_dict: JsonDict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
+ # Twisted will have processed the args by now.
+ assert request.args is not None
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
@@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_GET(self, request, room_id, event_type, txn_id):
+ def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Tuple[int, str]:
return 200, "Not implemented"
- def on_PUT(self, request, room_id, event_type, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id]
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
- def on_PUT(self, request, room_identifier, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_identifier: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
try:
@@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
return 200, data
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server")
@@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
# TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
@@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
users_with_profile = await self.message_handler.get_joined_members(
@@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
+ # Twisted will have processed the args by now.
+ assert request.args is not None
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
@@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, List[JsonDict]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
events = await self.message_handler.get_state_events(
@@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
@@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
event = await self.event_handler.get_event(
@@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ event_dict = await self._event_serializer.serialize_event(event, time_now)
+ return 200, event_dict
- return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet):
@@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
- def on_PUT(self, request, room_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/[invite|join|leave]
PATTERNS = (
"/rooms/(?P<room_id>[^/]*)/"
@@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ membership_action: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
@@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
- def _has_3pid_invite_keys(self, content):
+ def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
- def on_PUT(self, request, room_id, membership_action, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, event_id, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_id: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_PUT(self, request, room_id, event_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
hs.config.worker.writers.typing == hs.get_instance_name()
)
- async def on_PUT(self, request, room_id, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer:
@@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
alias_list = await self.directory_handler.get_aliases_for_room(
@@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
-def register_txn_path(servlet, regex_string, http_server, with_get=False):
+def register_txn_path(
+ servlet: RestServlet,
+ regex_string: str,
+ http_server: HttpServer,
+ with_get: bool = False,
+) -> None:
"""Registers a transaction-based path.
This registers two paths:
@@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
POST regex_string
Args:
- regex_string (str): The regex string to register. Must NOT have a
- trailing $ as this string will be appended to.
- http_server : The http_server to register paths with.
+ regex_string: The regex string to register. Must NOT have a
+ trailing $ as this string will be appended to.
+ http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
+ on_POST = getattr(servlet, "on_POST", None)
+ on_PUT = getattr(servlet, "on_PUT", None)
+ if on_POST is None or on_PUT is None:
+ raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
- servlet.on_POST,
+ on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_PUT,
+ on_PUT,
servlet.__class__.__name__,
)
+ on_GET = getattr(servlet, "on_GET", None)
if with_get:
+ if on_GET is None:
+ raise RuntimeError(
+ "register_txn_path called with with_get = True, but no on_GET method exists"
+ )
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_GET,
+ on_GET,
servlet.__class__.__name__,
)
@@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
-def register_servlets(hs: "HomeServer", http_server, is_worker=False):
+def register_servlets(
+ hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
+) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomForgetRestServlet(hs).register(http_server)
-def register_deprecated_servlets(hs, http_server):
+def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomInitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 3172aba605..ed96978448 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,10 +14,14 @@
import logging
import re
+from typing import TYPE_CHECKING, Awaitable, List, Tuple
+
+from twisted.web.server import Request
from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.appservice import ApplicationService
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -25,10 +29,14 @@ from synapse.http.servlet import (
parse_string,
parse_strings_from_args,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.store = hs.get_datastore()
@@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
- async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+ async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
(
most_recent_prev_event_id,
most_recent_prev_event_depth,
@@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
def _create_insertion_event_dict(
self, sender: str, room_id: str, origin_server_ts: int
- ):
+ ) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
and a random chunk ID.
@@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
origin_server_ts: Timestamp when the event was sent
Returns:
- Tuple of event ID and stream ordering position
+ The new event dictionary to insert.
"""
next_chunk_id = random_string(8)
@@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet):
return create_requester(user_id, app_service=app_service)
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
if not requester.app_service:
@@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["state_events_at_start", "events"])
+ assert request.args is not None
prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
chunk_id_from_query = parse_string(request, "chunk_id")
@@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
],
}
- def on_GET(self, request, room_id):
+ def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
return 501, "Not implemented"
- def on_PUT(self, request, room_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id
)
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
msc2716_enabled = hs.config.experimental.msc2716_enabled
if msc2716_enabled:
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 263596be86..37e39570f6 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,16 +13,23 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet):
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- async def on_PUT(self, request, room_id, session_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional)
@@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet):
ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
return 200, ret
- async def on_GET(self, request, room_id, session_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API.
@@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys
- async def on_DELETE(self, request, room_id, session_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Deletes one or more encrypted E2E room keys for a user for backup purposes.
@@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user
@@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- async def on_GET(self, request, version):
+ async def on_GET(
+ self, request: SynapseRequest, version: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the
@@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet):
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info
- async def on_DELETE(self, request, version):
+ async def on_DELETE(
+ self, request: SynapseRequest, version: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most
@@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet):
await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {}
- async def on_PUT(self, request, version):
+ async def on_PUT(
+ self, request: SynapseRequest, version: Optional[str]
+ ) -> Tuple[int, JsonDict]:
"""
Update the information about a given version of the user's room_keys backup.
@@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomKeysServlet(hs).register(http_server)
RoomKeysVersionServlet(hs).register(http_server)
RoomKeysNewVersionServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index d537d811d8..3322c8ef48 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -13,15 +13,21 @@
# limitations under the License.
import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Awaitable, Tuple
from synapse.http import servlet
+from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
)
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.device_message_handler = hs.get_device_message_handler()
@trace(opname="sendToDevice")
- def on_PUT(self, request, message_type, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, message_type: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("message_type", message_type)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
- async def _put(self, request, message_type, txn_id):
+ async def _put(
+ self, request: SynapseRequest, message_type: str, txn_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"]
)
- response: Tuple[int, dict] = (200, {})
- return response
+ return 200, {}
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SendToDeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 65c37be3e9..1259058b9b 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,12 +14,24 @@
import itertools
import logging
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.api.presence import UserPresenceState
+from synapse.events import EventBase
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
@@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet):
The room, encoded in our response format
"""
- def serialize(events):
+ def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
return self._event_serializer.serialize_events(
events,
time_now=time_now,
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 94ff3719ce..914fb3acf5 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -15,28 +15,37 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
+
+from twisted.python.failure import Failure
+from twisted.web.server import Request
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import JsonDict
from synapse.util.async_helpers import ObservableDeferred
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
- self.transactions = {
- # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
- }
+ # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
+ self.transactions: Dict[
+ str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
+ ] = {}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
- def _get_transaction_key(self, request):
+ def _get_transaction_key(self, request: Request) -> str:
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
@@ -45,15 +54,21 @@ class HttpTransactionCache:
path and the access_token for the requesting user.
Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
+ request: The incoming request. Must contain an access_token.
Returns:
- str: A transaction key
+ A transaction key
"""
+ assert request.path is not None
token = self.auth.get_access_token_from_request(request)
return request.path.decode("utf8") + "/" + token
- def fetch_or_execute_request(self, request, fn, *args, **kwargs):
+ def fetch_or_execute_request(
+ self,
+ request: Request,
+ fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
+ *args: Any,
+ **kwargs: Any,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -64,15 +79,20 @@ class HttpTransactionCache:
self._get_transaction_key(request), fn, *args, **kwargs
)
- def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
+ def fetch_or_execute(
+ self,
+ txn_key: str,
+ fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
+ *args: Any,
+ **kwargs: Any,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
- txn_key (str): A key to ensure idempotency should fetch_or_execute be
- called again at a later point in time.
- fn (function): A function which returns a tuple of
- (response_code, response_dict).
+ txn_key: A key to ensure idempotency should fetch_or_execute be
+ called again at a later point in time.
+ fn: A function which returns a tuple of (response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
@@ -90,7 +110,7 @@ class HttpTransactionCache:
# if the request fails with an exception, remove it
# from the transaction map. This is done to ensure that we don't
# cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
+ def remove_from_map(err: Failure) -> None:
self.transactions.pop(txn_key, None)
# we deliberately do not propagate the error any further, as we
# expect the observers to have reported it.
@@ -99,7 +119,7 @@ class HttpTransactionCache:
return make_deferred_yieldable(observable.observe())
- def _cleanup(self):
+ def _cleanup(self) -> None:
now = self.clock.time_msec()
for key in list(self.transactions):
ts = self.transactions[key][1]
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
new file mode 100644
index 0000000000..afe41823e4
--- /dev/null
+++ b/synapse/rest/media/v1/oembed.py
@@ -0,0 +1,135 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import attr
+
+from synapse.http.client import SimpleHttpClient
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html: Optional[str]
+ url: Optional[str]
+ title: Optional[str]
+ # Number of seconds to cache the content.
+ cache_age: int
+
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
+
+class OEmbedProvider:
+ """
+ A helper for accessing oEmbed content.
+
+ It can be used to check if a URL should be accessed via oEmbed and for
+ requesting/parsing oEmbed content.
+ """
+
+ def __init__(self, hs: "HomeServer", client: SimpleHttpClient):
+ self._oembed_patterns = {}
+ for oembed_endpoint in hs.config.oembed.oembed_patterns:
+ for pattern in oembed_endpoint.url_patterns:
+ self._oembed_patterns[pattern] = oembed_endpoint.api_endpoint
+ self._client = client
+
+ def get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Args:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in self._oembed_patterns.items():
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Args:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+ result = await self._client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ # Note that only the JSON format is supported.
+ args={"url": url},
+ )
+
+ # Ensure there's a version of 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0f051d4041..317d333b12 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -25,8 +25,6 @@ import traceback
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
-import attr
-
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
@@ -43,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -71,63 +70,6 @@ OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
-# A map of globs to API endpoints.
-_oembed_globs = {
- # Twitter.
- "https://publish.twitter.com/oembed": [
- "https://twitter.com/*/status/*",
- "https://*.twitter.com/*/status/*",
- "https://twitter.com/*/moments/*",
- "https://*.twitter.com/*/moments/*",
- # Include the HTTP versions too.
- "http://twitter.com/*/status/*",
- "http://*.twitter.com/*/status/*",
- "http://twitter.com/*/moments/*",
- "http://*.twitter.com/*/moments/*",
- ],
-}
-# Convert the globs to regular expressions.
-_oembed_patterns = {}
-for endpoint, globs in _oembed_globs.items():
- for glob in globs:
- # Convert the glob into a sane regular expression to match against. The
- # rules followed will be slightly different for the domain portion vs.
- # the rest.
- #
- # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
- # 2. The domain can have globs, but we limit it to characters that can
- # reasonably be a domain part.
- # TODO: This does not attempt to handle Unicode domain names.
- # 3. Other parts allow a glob to be any one, or more, characters.
- results = urlparse.urlparse(glob)
-
- # Ensure the scheme does not have wildcards (and is a sane scheme).
- if results.scheme not in {"http", "https"}:
- raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
-
- pattern = urlparse.urlunparse(
- [
- results.scheme,
- re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
- ]
- + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
- )
- _oembed_patterns[re.compile(pattern)] = endpoint
-
-
-@attr.s(slots=True)
-class OEmbedResult:
- # Either HTML content or URL must be provided.
- html = attr.ib(type=Optional[str])
- url = attr.ib(type=Optional[str])
- title = attr.ib(type=Optional[str])
- # Number of seconds to cache the content.
- cache_age = attr.ib(type=int)
-
-
-class OEmbedError(Exception):
- """An error occurred processing the oEmbed object."""
-
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
@@ -157,6 +99,8 @@ class PreviewUrlResource(DirectServeJsonResource):
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
+ self._oembed = OEmbedProvider(hs, self.client)
+
# We run the background jobs if we're the instance specified (or no
# instance is specified, where we assume there is only one instance
# serving media).
@@ -367,87 +311,6 @@ class PreviewUrlResource(DirectServeJsonResource):
return jsonog.encode("utf8")
- def _get_oembed_url(self, url: str) -> Optional[str]:
- """
- Check whether the URL should be downloaded as oEmbed content instead.
-
- Args:
- url: The URL to check.
-
- Returns:
- A URL to use instead or None if the original URL should be used.
- """
- for url_pattern, endpoint in _oembed_patterns.items():
- if url_pattern.fullmatch(url):
- return endpoint
-
- # No match.
- return None
-
- async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
- """
- Request content from an oEmbed endpoint.
-
- Args:
- endpoint: The oEmbed API endpoint.
- url: The URL to pass to the API.
-
- Returns:
- An object representing the metadata returned.
-
- Raises:
- OEmbedError if fetching or parsing of the oEmbed information fails.
- """
- try:
- logger.debug("Trying to get oEmbed content for url '%s'", url)
- result = await self.client.get_json(
- endpoint,
- # TODO Specify max height / width.
- # Note that only the JSON format is supported.
- args={"url": url},
- )
-
- # Ensure there's a version of 1.0.
- if result.get("version") != "1.0":
- raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
- oembed_type = result.get("type")
-
- # Ensure the cache age is None or an int.
- cache_age = result.get("cache_age")
- if cache_age:
- cache_age = int(cache_age)
-
- oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
-
- # HTML content.
- if oembed_type == "rich":
- oembed_result.html = result.get("html")
- return oembed_result
-
- if oembed_type == "photo":
- oembed_result.url = result.get("url")
- return oembed_result
-
- # TODO Handle link and video types.
-
- if "thumbnail_url" in result:
- oembed_result.url = result.get("thumbnail_url")
- return oembed_result
-
- raise OEmbedError("Incompatible oEmbed information.")
-
- except OEmbedError as e:
- # Trap OEmbedErrors first so we can directly re-raise them.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- raise
-
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
- raise OEmbedError() from e
-
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -459,11 +322,11 @@ class PreviewUrlResource(DirectServeJsonResource):
# If this URL can be accessed via oEmbed, use that instead.
url_to_download: Optional[str] = url
- oembed_url = self._get_oembed_url(url)
+ oembed_url = self._oembed.get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
try:
- oembed_result = await self._get_oembed_content(oembed_url, url)
+ oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
if oembed_result.url:
url_to_download = oembed_result.url
elif oembed_result.html:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 95d2caff62..0084d9f96c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -280,18 +280,18 @@ class LoggingTransaction:
else:
self.executemany(sql, args)
- def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+ def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
- Always sets fetch=True when caling `execute_values`, so will return the
- results.
+ The `fetch` parameter must be set to False if the query does not return
+ rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
from psycopg2.extras import execute_values # type: ignore
return self._do_execute(
- lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+ lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
)
def execute(self, sql: str, *args: Any) -> None:
@@ -920,13 +920,23 @@ class DatabasePool:
if k != keys[0]:
raise RuntimeError("All items must have the same keys")
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0]),
- )
+ if isinstance(txn.database_engine, PostgresEngine):
+ # We use `execute_values` as it can be a lot faster than `execute_batch`,
+ # but it's only available on postgres.
+ sql = "INSERT INTO %s (%s) VALUES ?" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ )
- txn.execute_batch(sql, vals)
+ txn.execute_values(sql, vals, fetch=False)
+ else:
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ ", ".join("?" for _ in keys[0]),
+ )
+
+ txn.execute_batch(sql, vals)
async def simple_upsert(
self,
@@ -1281,20 +1291,33 @@ class DatabasePool:
k + "=EXCLUDED." + k for k in value_names
)
- sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
- table,
- ", ".join(k for k in allnames),
- ", ".join("?" for _ in allnames),
- ", ".join(key_names),
- latter,
- )
-
args = []
for x, y in zip(key_values, value_values):
args.append(tuple(x) + tuple(y))
- return txn.execute_batch(sql, args)
+ if isinstance(txn.database_engine, PostgresEngine):
+ # We use `execute_values` as it can be a lot faster than `execute_batch`,
+ # but it's only available on postgres.
+ sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % (
+ table,
+ ", ".join(k for k in allnames),
+ ", ".join(key_names),
+ latter,
+ )
+
+ txn.execute_values(sql, args, fetch=False)
+
+ else:
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+ table,
+ ", ".join(k for k in allnames),
+ ", ".join("?" for _ in allnames),
+ ", ".join(key_names),
+ latter,
+ )
+
+ return txn.execute_batch(sql, args)
@overload
async def simple_select_one(
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 86075bc55b..6daf8b8ffb 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -75,8 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore):
desc="get_aliases_for_room",
)
-
-class DirectoryStore(DirectoryWorkerStore):
async def create_room_alias_association(
self,
room_alias: RoomAlias,
@@ -126,6 +124,8 @@ class DirectoryStore(DirectoryWorkerStore):
409, "Room alias %s already exists" % room_alias.to_string()
)
+
+class DirectoryStore(DirectoryWorkerStore):
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 40b53274fb..f07e288056 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -575,7 +575,13 @@ class PersistEventsStore:
missing_auth_chains.clear()
- for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+ for (
+ auth_id,
+ event_type,
+ state_key,
+ chain_id,
+ sequence_number,
+ ) in txn.fetchall():
event_to_types[auth_id] = (event_type, state_key)
if chain_id is None:
@@ -1379,18 +1385,18 @@ class PersistEventsStore:
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
- sql = "UPDATE redactions SET have_censored = ? WHERE redacts = ?"
- txn.execute_batch(
- sql,
- (
- (
- False,
- event.event_id,
- )
- for event, _ in events_and_contexts
- if not event.internal_metadata.is_redacted()
- ),
+ unredacted_events = [
+ event.event_id
+ for event, _ in events_and_contexts
+ if not event.internal_metadata.is_redacted()
+ ]
+ sql = "UPDATE redactions SET have_censored = ? WHERE "
+ clause, args = make_in_list_sql_clause(
+ self.database_engine,
+ "redacts",
+ unredacted_events,
)
+ txn.execute(sql + clause, [False] + args)
state_events_and_contexts = [
ec for ec in events_and_contexts if ec[0].is_state()
@@ -1770,10 +1776,21 @@ class PersistEventsStore:
# Not a insertion event
return
- # Skip processing a insertion event if the room version doesn't
- # support it.
+ # Skip processing an insertion event if the room version doesn't
+ # support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
- if not room_version.msc2716_historical:
+ room_creator = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": event.room_id},
+ retcol="creator",
+ allow_none=True,
+ )
+ if (
+ not room_version.msc2716_historical
+ or not self.hs.config.experimental.msc2716_enabled
+ or event.sender != room_creator
+ ):
return
next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
@@ -1822,9 +1839,20 @@ class PersistEventsStore:
return
# Skip processing a chunk event if the room version doesn't
- # support it.
+ # support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
- if not room_version.msc2716_historical:
+ room_creator = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": event.room_id},
+ retcol="creator",
+ allow_none=True,
+ )
+ if (
+ not room_version.msc2716_historical
+ or not self.hs.config.experimental.msc2716_enabled
+ or event.sender != room_creator
+ ):
return
chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 1388771c40..12cf6995eb 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -29,7 +29,26 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-class PresenceStore(SQLBaseStore):
+class PresenceBackgroundUpdateStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: Connection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ # Used by `PresenceStore._get_active_presence()`
+ self.db_pool.updates.register_background_index_update(
+ "presence_stream_not_offline_index",
+ index_name="presence_stream_state_not_offline_idx",
+ table="presence_stream",
+ columns=["state"],
+ where_clause="state != 'offline'",
+ )
+
+
+class PresenceStore(PresenceBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
@@ -332,6 +351,8 @@ class PresenceStore(SQLBaseStore):
the appropriate time outs.
"""
+ # The `presence_stream_state_not_offline_idx` index should be used for this
+ # query.
sql = (
"SELECT user_id, state, last_active_ts, last_federation_update_ts,"
" last_user_sync_ts, status_msg, currently_active FROM presence_stream"
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f98b892598..6e7312266d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -19,9 +19,10 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore
@@ -1013,6 +1014,7 @@ class _BackgroundUpdates:
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
+ POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1054,6 +1056,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_replace_room_depth_min_depth,
)
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
+ self._background_populate_rooms_creator_column,
+ )
+
async def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
@@ -1273,7 +1280,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
keyvalues={"room_id": room_id},
retcol="MAX(stream_ordering)",
allow_none=True,
- desc="upsert_room_on_join",
+ desc="has_auth_chain_index_fallback",
)
return max_ordering is None
@@ -1343,6 +1350,65 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return 0
+ async def _background_populate_rooms_creator_column(
+ self, progress: dict, batch_size: int
+ ):
+ """Background update to go and add creator information to `rooms`
+ table from `current_state_events` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+ sql = """
+ SELECT room_id, json FROM event_json
+ INNER JOIN rooms AS room USING (room_id)
+ INNER JOIN current_state_events AS state_event USING (room_id, event_id)
+ WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = ''
+ ORDER BY room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+ room_id_to_create_event_results = txn.fetchall()
+
+ new_last_room_id = ""
+ for room_id, event_json in room_id_to_create_event_results:
+ event_dict = db_to_json(event_json)
+
+ creator = event_dict.get("content").get(EventContentFields.ROOM_CREATOR)
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"creator": creator},
+ )
+ new_last_room_id = room_id
+
+ if new_last_room_id == "":
+ return True
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
+ {"room_id": new_last_room_id},
+ )
+
+ return False
+
+ end = await self.db_pool.runInteraction(
+ "_background_populate_rooms_creator_column",
+ _background_populate_rooms_creator_column_txn,
+ )
+
+ if end:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN
+ )
+
+ return batch_size
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1350,7 +1416,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.config = hs.config
- async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+ async def upsert_room_on_join(
+ self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+ ):
"""Ensure that the room is stored in the table
Called when we join a room over federation, and overwrites any room version
@@ -1361,6 +1429,24 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# mark the room as having an auth chain cover index.
has_auth_chain_index = await self.has_auth_chain_index(room_id)
+ create_event = None
+ for e in auth_events:
+ if (e.type, e.state_key) == (EventTypes.Create, ""):
+ create_event = e
+ break
+
+ if create_event is None:
+ # If the state doesn't have a create event then the room is
+ # invalid, and it would fail auth checks anyway.
+ raise StoreError(400, "No create event in state")
+
+ room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+
+ if not isinstance(room_creator, str):
+ # If the create event does not have a creator then the room is
+ # invalid, and it would fail auth checks anyway.
+ raise StoreError(400, "No creator defined on the create event")
+
await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
@@ -1368,7 +1454,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
values={"room_version": room_version.identifier},
insertion_values={
"is_public": False,
- "creator": "",
+ "creator": room_creator,
"has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
@@ -1396,6 +1482,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
insertion_values={
"room_version": room_version.identifier,
"is_public": False,
+ # We don't worry about setting the `creator` here because
+ # we don't process any messages in a room while a user is
+ # invited (only after the join).
"creator": "",
"has_auth_chain_index": has_auth_chain_index,
},
diff --git a/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql
new file mode 100644
index 0000000000..f7c0b31261
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json)
+ VALUES (6302, 'populate_rooms_creator_column', '{}');
diff --git a/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql
new file mode 100644
index 0000000000..b90856004b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6304, 'presence_stream_not_offline_index', '{}');
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index c768fdea56..6f7cbe40f4 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -19,6 +19,7 @@ from contextlib import contextmanager
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import attr
+from sortedcontainers import SortedSet
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -240,7 +241,7 @@ class MultiWriterIdGenerator:
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
- self._unfinished_ids: Set[int] = set()
+ self._unfinished_ids: SortedSet[int] = SortedSet()
# Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs.
@@ -473,7 +474,7 @@ class MultiWriterIdGenerator:
finished = set()
- min_unfinshed = min(self._unfinished_ids)
+ min_unfinshed = self._unfinished_ids[0]
for s in self._finished_ids:
if s < min_unfinshed:
if new_cur is None or new_cur < s:
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 522daa323d..cfb5b94ca9 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -61,7 +61,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
-----END RSA PRIVATE KEY-----"""
-def manhole(username, password, globals):
+def manhole(settings, globals):
"""Starts a ssh listener with password authentication using
the given username and password. Clients connecting to the ssh
listener will find themselves in a colored python shell with
@@ -75,6 +75,15 @@ def manhole(username, password, globals):
Returns:
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
"""
+ username = settings.username
+ password = settings.password
+ priv_key = settings.priv_key
+ if priv_key is None:
+ priv_key = Key.fromString(PRIVATE_KEY)
+ pub_key = settings.pub_key
+ if pub_key is None:
+ pub_key = Key.fromString(PUBLIC_KEY)
+
if not isinstance(password, bytes):
password = password.encode("ascii")
@@ -86,8 +95,8 @@ def manhole(username, password, globals):
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
- factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
- factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
+ factory.privateKeys[b"ssh-rsa"] = priv_key
+ factory.publicKeys[b"ssh-rsa"] = pub_key
return factory
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index 6f2b9e997d..b6f21294ba 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -35,7 +35,7 @@ class ServerConfigTestCase(unittest.TestCase):
def test_unsecure_listener_no_listeners_open_private_ports_false(self):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "che.org", "/data_dir_path", False, None
+ "che.org", "/data_dir_path", False, None, config_dir_path="CONFDIR"
)
)
@@ -55,7 +55,7 @@ class ServerConfigTestCase(unittest.TestCase):
def test_unsecure_listener_no_listeners_open_private_ports_true(self):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "che.org", "/data_dir_path", True, None
+ "che.org", "/data_dir_path", True, None, config_dir_path="CONFDIR"
)
)
@@ -89,7 +89,7 @@ class ServerConfigTestCase(unittest.TestCase):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "this.one.listens", "/data_dir_path", True, listeners
+ "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR"
)
)
@@ -123,7 +123,7 @@ class ServerConfigTestCase(unittest.TestCase):
conf = yaml.safe_load(
ServerConfig().generate_config_section(
- "this.one.listens", "/data_dir_path", True, listeners
+ "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR"
)
)
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
new file mode 100644
index 0000000000..fcde5dab72
--- /dev/null
+++ b/tests/handlers/test_room.py
@@ -0,0 +1,108 @@
+import synapse
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms
+from synapse.rest.client import login, room
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+ def test_encrypted_by_default_config_option_all(self):
+ """Tests that invite-only and non-invite-only rooms have encryption enabled by
+ default when the config option encryption_enabled_by_default_for_room_type is "all".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+ def test_encrypted_by_default_config_option_invite(self):
+ """Tests that only new, invite-only rooms have encryption enabled by default when
+ the config option encryption_enabled_by_default_for_room_type is "invite".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room has an encryption state event
+ event_content = self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ )
+ self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+ def test_encrypted_by_default_config_option_off(self):
+ """Tests that neither new invite-only nor non-invite-only rooms have encryption
+ enabled by default when the config option
+ encryption_enabled_by_default_for_room_type is "off".
+ """
+ # Create a user
+ user = self.register_user("user", "pass")
+ user_token = self.login(user, "pass")
+
+ # Create an invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
+
+ # Create a non invite-only room as that user
+ room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+ # Check that the room does not have an encryption state event
+ self.helper.get_state(
+ room_id=room_id,
+ event_type=EventTypes.RoomEncryption,
+ tok=user_token,
+ expect_code=404,
+ )
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index ac800afa7d..d3d0bf1ac5 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -35,10 +35,11 @@ from synapse.types import JsonDict, UserID
from tests import unittest
-def _create_event(room_id: str, order: Optional[Any] = None):
- result = mock.Mock()
+def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
+ result = mock.Mock(name=room_id)
result.room_id = room_id
result.content = {}
+ result.origin_server_ts = origin_server_ts
if order is not None:
result.content["order"] = order
return result
@@ -63,10 +64,17 @@ class TestSpaceSummarySort(unittest.TestCase):
self.assertEqual([ev2, ev1], _order(ev1, ev2))
+ def test_order_origin_server_ts(self):
+ """Origin server is a tie-breaker for ordering."""
+ ev1 = _create_event("!abc:test", origin_server_ts=10)
+ ev2 = _create_event("!xyz:test", origin_server_ts=30)
+
+ self.assertEqual([ev1, ev2], _order(ev1, ev2))
+
def test_order_room_id(self):
- """Room ID is a tie-breaker for ordering."""
- ev1 = _create_event("!abc:test", "abc")
- ev2 = _create_event("!xyz:test", "abc")
+ """Room ID is a final tie-breaker for ordering."""
+ ev1 = _create_event("!abc:test")
+ ev2 = _create_event("!xyz:test")
self.assertEqual([ev1, ev2], _order(ev1, ev2))
@@ -573,6 +581,31 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_hierarchy(result, expected)
+ def test_unknown_room_version(self):
+ """
+ If an room with an unknown room version is encountered it should not cause
+ the entire summary to skip.
+ """
+ # Poke the database and update the room version to an unknown one.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": self.room},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The result should have only the space, along with a link from space -> room.
+ expected = [(self.space, [self.room])]
+ self._assert_rooms(result, expected)
+
+ result = self.get_success(
+ self.handler.get_room_hierarchy(self.user, self.space)
+ )
+ self._assert_hierarchy(result, expected)
+
def test_fed_complex(self):
"""
Return data over federation and ensure that it is handled properly.
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index e44bf2b3b1..a91d31ce61 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -16,7 +16,7 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
+from synapse.api.constants import UserTypes
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.rest.client import login, room, user_directory
from synapse.storage.roommember import ProfileInfo
@@ -187,100 +187,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
- @override_config({"encryption_enabled_by_default_for_room_type": "all"})
- def test_encrypted_by_default_config_option_all(self):
- """Tests that invite-only and non-invite-only rooms have encryption enabled by
- default when the config option encryption_enabled_by_default_for_room_type is "all".
- """
- # Create a user
- user = self.register_user("user", "pass")
- user_token = self.login(user, "pass")
-
- # Create an invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
-
- # Check that the room has an encryption state event
- event_content = self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- )
- self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
-
- # Create a non invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
-
- # Check that the room has an encryption state event
- event_content = self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- )
- self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
-
- @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
- def test_encrypted_by_default_config_option_invite(self):
- """Tests that only new, invite-only rooms have encryption enabled by default when
- the config option encryption_enabled_by_default_for_room_type is "invite".
- """
- # Create a user
- user = self.register_user("user", "pass")
- user_token = self.login(user, "pass")
-
- # Create an invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
-
- # Check that the room has an encryption state event
- event_content = self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- )
- self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
-
- # Create a non invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
-
- # Check that the room does not have an encryption state event
- self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- expect_code=404,
- )
-
- @override_config({"encryption_enabled_by_default_for_room_type": "off"})
- def test_encrypted_by_default_config_option_off(self):
- """Tests that neither new invite-only nor non-invite-only rooms have encryption
- enabled by default when the config option
- encryption_enabled_by_default_for_room_type is "off".
- """
- # Create a user
- user = self.register_user("user", "pass")
- user_token = self.login(user, "pass")
-
- # Create an invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
-
- # Check that the room does not have an encryption state event
- self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- expect_code=404,
- )
-
- # Create a non invite-only room as that user
- room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
-
- # Check that the room does not have an encryption state event
- self.helper.get_state(
- room_id=room_id,
- event_type=EventTypes.RoomEncryption,
- tok=user_token,
- expect_code=404,
- )
-
def test_spam_checker(self):
"""
A user which fails the spam checks will not appear in search results.
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index c4ba13a6b2..fa8018e5a7 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import email.message
import os
+from typing import Dict, List, Sequence, Tuple
import attr
import pkg_resources
@@ -70,9 +71,10 @@ class EmailPusherTests(HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
# List[Tuple[Deferred, args, kwargs]]
- self.email_attempts = []
+ self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = []
def sendmail(*args, **kwargs):
+ # This mocks out synapse.reactor.send_email._sendmail.
d = Deferred()
self.email_attempts.append((d, args, kwargs))
return d
@@ -255,6 +257,39 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_room_notifications_include_avatar(self):
+ # Create a room and set its avatar.
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.send_state(
+ room, "m.room.avatar", {"url": "mxc://DUMMY_MEDIA_ID"}, self.access_token
+ )
+
+ # Invite two other uses.
+ for other in self.others:
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=room, user=other.id, tok=other.token)
+
+ # The other users send some messages.
+ # TODO It seems that two messages are required to trigger an email?
+ self.helper.send(room, body="Alpha", tok=self.others[0].token)
+ self.helper.send(room, body="Beta", tok=self.others[1].token)
+
+ # We should get emailed about those messages
+ args, kwargs = self._check_for_mail()
+
+ # That email should contain the room's avatar
+ msg: bytes = args[5]
+ # Multipart: plain text, base 64 encoded; html, base 64 encoded
+ html = (
+ email.message_from_bytes(msg)
+ .get_payload()[1]
+ .get_payload(decode=True)
+ .decode()
+ )
+ self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
+
def test_empty_room(self):
"""All users leaving a room shouldn't cause the pusher to break."""
# Create a simple room with two users
@@ -388,9 +423,14 @@ class EmailPusherTests(HomeserverTestCase):
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
- def _check_for_mail(self):
- """Check that the user receives an email notification"""
+ def _check_for_mail(self) -> Tuple[Sequence, Dict]:
+ """
+ Assert that synapse sent off exactly one email notification.
+ Returns:
+ args and kwargs passed to synapse.reactor.send_email._sendmail for
+ that notification.
+ """
# Get the stream ordering before it gets sent
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
@@ -413,8 +453,9 @@ class EmailPusherTests(HomeserverTestCase):
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
+ deferred, sendmail_args, sendmail_kwargs = self.email_attempts[0]
# Make the email succeed
- self.email_attempts[0][0].callback(True)
+ deferred.callback(True)
self.pump()
# One email was attempted to be sent
@@ -430,3 +471,4 @@ class EmailPusherTests(HomeserverTestCase):
# Reset the attempts.
self.email_attempts = []
+ return sendmail_args, sendmail_kwargs
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index d3ef7bb4c6..7fa9027227 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -14,13 +14,14 @@
import json
import os
import re
-from unittest.mock import patch
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
+from synapse.config.oembed import OEmbedEndpointConfig
+
from tests import unittest
from tests.server import FakeTransport
@@ -81,6 +82,19 @@ class URLPreviewTests(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
+ # After the hs is created, modify the parsed oEmbed config (to avoid
+ # messing with files).
+ #
+ # Note that HTTP URLs are used to avoid having to deal with TLS in tests.
+ hs.config.oembed.oembed_patterns = [
+ OEmbedEndpointConfig(
+ api_endpoint="http://publish.twitter.com/oembed",
+ url_patterns=[
+ re.compile(r"http://twitter\.com/.+/status/.+"),
+ ],
+ )
+ ]
+
return hs
def prepare(self, reactor, clock, hs):
@@ -544,123 +558,101 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def test_oembed_photo(self):
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
- # Route the HTTP version to an HTTP endpoint so that the tests work.
- with patch.dict(
- "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
- {
- re.compile(
- r"http://twitter\.com/.+/status/.+"
- ): "http://publish.twitter.com/oembed",
- },
- clear=True,
- ):
-
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "photo",
- "url": "http://cdn.twitter.com/matrixdotorg",
- }
- oembed_content = json.dumps(result).encode("utf-8")
-
- end_content = (
- b"<html><head>"
- b"<title>Some Title</title>"
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
- channel = self.make_request(
- "GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(oembed_content),)
- + oembed_content
- )
+ result = {
+ "version": "1.0",
+ "type": "photo",
+ "url": "http://cdn.twitter.com/matrixdotorg",
+ }
+ oembed_content = json.dumps(result).encode("utf-8")
- self.pump()
-
- client = self.reactor.tcpClients[1][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
+ end_content = (
+ b"<html><head>"
+ b"<title>Some Title</title>"
+ b'<meta property="og:description" content="hi" />'
+ b"</head></html>"
+ )
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
)
+ % (len(oembed_content),)
+ + oembed_content
+ )
- self.pump()
+ self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ client = self.reactor.tcpClients[1][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: text/html; charset="utf8"\r\n\r\n'
)
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+ )
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
- # Route the HTTP version to an HTTP endpoint so that the tests work.
- with patch.dict(
- "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
- {
- re.compile(
- r"http://twitter\.com/.+/status/.+"
- ): "http://publish.twitter.com/oembed",
- },
- clear=True,
- ):
-
- self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
-
- result = {
- "version": "1.0",
- "type": "rich",
- "html": "<div>Content Preview</div>",
- }
- end_content = json.dumps(result).encode("utf-8")
-
- channel = self.make_request(
- "GET",
- "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
- shorthand=False,
- await_result=False,
- )
- self.pump()
-
- client = self.reactor.tcpClients[0][2].buildProtocol(None)
- server = AccumulatingProtocol()
- server.makeConnection(FakeTransport(client, self.reactor))
- client.makeConnection(FakeTransport(server, self.reactor))
- client.dataReceived(
- (
- b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: application/json; charset="utf8"\r\n\r\n'
- )
- % (len(end_content),)
- + end_content
- )
+ self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ result = {
+ "version": "1.0",
+ "type": "rich",
+ "html": "<div>Content Preview</div>",
+ }
+ end_content = json.dumps(result).encode("utf-8")
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
- self.pump()
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ (
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+ b'Content-Type: application/json; charset="utf8"\r\n\r\n'
)
+ % (len(end_content),)
+ + end_content
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {"og:title": None, "og:description": "Content Preview"},
+ )
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
new file mode 100644
index 0000000000..ffee707153
--- /dev/null
+++ b/tests/storage/databases/main/test_room.py
@@ -0,0 +1,98 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.storage.databases.main.room import _BackgroundUpdates
+
+from tests.unittest import HomeserverTestCase
+
+
+class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.user_id = self.register_user("foo", "pass")
+ self.token = self.login("foo", "pass")
+
+ def _generate_room(self) -> str:
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ return room_id
+
+ def test_background_populate_rooms_creator_column(self):
+ """Test that the background update to populate the rooms creator column
+ works properly.
+ """
+
+ # Insert a room without the creator
+ room_id = self._generate_room()
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"creator": None},
+ desc="test",
+ )
+ )
+
+ # Make sure the test is starting out with a room without a creator
+ room_creator_before = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="creator",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_creator_before, None)
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Make sure the background update filled in the room creator
+ room_creator_after = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="creator",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_creator_after, self.user_id)
|