diff --git a/INSTALL.md b/INSTALL.md
index b88d826f6c..1934593148 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -35,7 +35,7 @@ virtualenv -p python3 ~/synapse/env
source ~/synapse/env/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools
-pip install matrix-synapse[all]
+pip install matrix-synapse
```
This will download Synapse from [PyPI](https://pypi.org/project/matrix-synapse)
@@ -48,7 +48,7 @@ update flag:
```
source ~/synapse/env/bin/activate
-pip install -U matrix-synapse[all]
+pip install -U matrix-synapse
```
Before you can start Synapse, you will need to generate a configuration
diff --git a/changelog.d/3484.misc b/changelog.d/3484.misc
new file mode 100644
index 0000000000..3645849844
--- /dev/null
+++ b/changelog.d/3484.misc
@@ -0,0 +1 @@
+Make /sync attempt to return device updates for both joined and invited users. Note that this doesn't currently work correctly due to other bugs.
diff --git a/changelog.d/5039.bugfix b/changelog.d/5039.bugfix
new file mode 100644
index 0000000000..212cff7ae8
--- /dev/null
+++ b/changelog.d/5039.bugfix
@@ -0,0 +1 @@
+Fix image orientation when generating thumbnails (needs pillow>=4.3.0). Contributed by Pau Rodriguez-Estivill.
diff --git a/changelog.d/5174.bugfix b/changelog.d/5174.bugfix
new file mode 100644
index 0000000000..0f26d46b2c
--- /dev/null
+++ b/changelog.d/5174.bugfix
@@ -0,0 +1 @@
+Re-order stages in registration flows such that msisdn and email verification are done last.
diff --git a/changelog.d/5177.bugfix b/changelog.d/5177.bugfix
new file mode 100644
index 0000000000..c2f1644ae5
--- /dev/null
+++ b/changelog.d/5177.bugfix
@@ -0,0 +1 @@
+Fix 3pid guest invites.
diff --git a/changelog.d/5187.bugfix b/changelog.d/5187.bugfix
new file mode 100644
index 0000000000..df176cf5b2
--- /dev/null
+++ b/changelog.d/5187.bugfix
@@ -0,0 +1 @@
+Fix a bug where the register endpoint would fail with M_THREEPID_IN_USE instead of returning an account previously registered in the same session.
diff --git a/changelog.d/5191.misc b/changelog.d/5191.misc
new file mode 100644
index 0000000000..e0615fec9c
--- /dev/null
+++ b/changelog.d/5191.misc
@@ -0,0 +1 @@
+Make generating SQL bounds for pagination generic.
diff --git a/changelog.d/5196.feature b/changelog.d/5196.feature
new file mode 100644
index 0000000000..1ffb928f62
--- /dev/null
+++ b/changelog.d/5196.feature
@@ -0,0 +1 @@
+Add an option to disable per-room profiles.
diff --git a/changelog.d/5197.misc b/changelog.d/5197.misc
new file mode 100644
index 0000000000..fca1d86b2e
--- /dev/null
+++ b/changelog.d/5197.misc
@@ -0,0 +1 @@
+Stop telling people to install the optional dependencies by default.
diff --git a/changelog.d/5198.bugfix b/changelog.d/5198.bugfix
new file mode 100644
index 0000000000..c6b156f17d
--- /dev/null
+++ b/changelog.d/5198.bugfix
@@ -0,0 +1 @@
+Prevent registration for user IDs that are too long to fit into a state key. Contributed by Reid Anderson.
\ No newline at end of file
diff --git a/changelog.d/5209.feature b/changelog.d/5209.feature
new file mode 100644
index 0000000000..747098c166
--- /dev/null
+++ b/changelog.d/5209.feature
@@ -0,0 +1 @@
+Add experimental support for relations (aka reactions and edits).
diff --git a/changelog.d/5210.feature b/changelog.d/5210.feature
new file mode 100644
index 0000000000..a7476bf9b9
--- /dev/null
+++ b/changelog.d/5210.feature
@@ -0,0 +1 @@
+Add a new room version which uses a new event ID format.
diff --git a/changelog.d/5211.feature b/changelog.d/5211.feature
new file mode 100644
index 0000000000..747098c166
--- /dev/null
+++ b/changelog.d/5211.feature
@@ -0,0 +1 @@
+Add experimental support for relations (aka reactions and edits).
diff --git a/debian/test/.gitignore b/debian/test/.gitignore
new file mode 100644
index 0000000000..95eda73fcc
--- /dev/null
+++ b/debian/test/.gitignore
@@ -0,0 +1,2 @@
+.vagrant
+*.log
diff --git a/debian/test/provision.sh b/debian/test/provision.sh
new file mode 100644
index 0000000000..a5c7f59712
--- /dev/null
+++ b/debian/test/provision.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+#
+# provisioning script for vagrant boxes for testing the matrix-synapse debs.
+#
+# Will install the most recent matrix-synapse-py3 deb for this platform from
+# the /debs directory.
+
+set -e
+
+apt-get update
+apt-get install -y lsb-release
+
+deb=`ls /debs/matrix-synapse-py3_*+$(lsb_release -cs)*.deb | sort | tail -n1`
+
+debconf-set-selections <<EOF
+matrix-synapse matrix-synapse/report-stats boolean false
+matrix-synapse matrix-synapse/server-name string localhost:18448
+EOF
+
+dpkg -i "$deb"
+
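+# Move the listeners off the default ports (8448 -> 18448, 8008 -> 18008) and
+# append a registration shared secret so that test users can be created.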
+sed -i -e '/port: 8...$/{s/8448/18448/; s/8008/18008/}' -e '$aregistration_shared_secret: secret' /etc/matrix-synapse/homeserver.yaml
+systemctl restart matrix-synapse
diff --git a/debian/test/stretch/Vagrantfile b/debian/test/stretch/Vagrantfile
new file mode 100644
index 0000000000..d8eff6fe11
--- /dev/null
+++ b/debian/test/stretch/Vagrantfile
@@ -0,0 +1,13 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+
+ver = `cd ../../..; dpkg-parsechangelog -S Version`.strip()
+
+Vagrant.configure("2") do |config|
+ config.vm.box = "debian/stretch64"
+
+ config.vm.synced_folder ".", "/vagrant", disabled: true
+ config.vm.synced_folder "../../../../debs", "/debs", type: "nfs"
+
+ config.vm.provision "shell", path: "../provision.sh"
+end
diff --git a/debian/test/xenial/Vagrantfile b/debian/test/xenial/Vagrantfile
new file mode 100644
index 0000000000..189236da17
--- /dev/null
+++ b/debian/test/xenial/Vagrantfile
@@ -0,0 +1,10 @@
+# -*- mode: ruby -*-
+# vi: set ft=ruby :
+
+Vagrant.configure("2") do |config|
+ config.vm.box = "ubuntu/xenial64"
+
+ config.vm.synced_folder ".", "/vagrant", disabled: true
+ config.vm.synced_folder "../../../../debs", "/debs"
+ config.vm.provision "shell", path: "../provision.sh"
+end
diff --git a/docs/postgres.rst b/docs/postgres.rst
index f7ebbed0c3..e81e10403f 100644
--- a/docs/postgres.rst
+++ b/docs/postgres.rst
@@ -3,6 +3,28 @@ Using Postgres
Postgres version 9.4 or later is known to work.
+Install postgres client libraries
+=================================
+
+Synapse will require the python postgres client library in order to connect to
+a postgres database.
+
+* If you are using the `matrix.org debian/ubuntu
+ packages <../INSTALL.md#matrixorg-packages>`_,
+ the necessary libraries will already be installed.
+
+* For other pre-built packages, please consult the documentation from the
+ relevant package.
+
+* If you installed synapse `in a virtualenv
+ <../INSTALL.md#installing-from-source>`_, you can install the library with::
+
+ ~/synapse/env/bin/pip install matrix-synapse[postgres]
+
+ (substituting the path to your virtualenv for ``~/synapse/env``, if you used a
+ different path). You will require the postgres development files. These are in
+ the ``libpq-dev`` package on Debian-derived distributions.
+
Set up database
===============
@@ -26,29 +48,6 @@ encoding use, e.g.::
This would create an appropriate database named ``synapse`` owned by the
``synapse_user`` user (which must already exist).
-Set up client in Debian/Ubuntu
-===========================
-
-Postgres support depends on the postgres python connector ``psycopg2``. In the
-virtual env::
-
- sudo apt-get install libpq-dev
- pip install psycopg2
-
-Set up client in RHEL/CentOs 7
-==============================
-
-Make sure you have the appropriate version of postgres-devel installed. For a
-postgres 9.4, use the postgres 9.4 packages from
-[here](https://wiki.postgresql.org/wiki/YUM_Installation).
-
-As with Debian/Ubuntu, postgres support depends on the postgres python connector
-``psycopg2``. In the virtual env::
-
- sudo yum install postgresql-devel libpqxx-devel.x86_64
- export PATH=/usr/pgsql-9.4/bin/:$PATH
- pip install psycopg2
-
Tuning Postgres
===============
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 09ee0e8984..d218aefee5 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -276,6 +276,12 @@ listeners:
#
#require_membership_for_aliases: false
+# Whether to allow per-room membership profiles, i.e. membership events sent
+# with profile information that differs from the target's global profile.
+# Defaults to 'true'.
+#
+#allow_per_room_profiles: false
+
## TLS ##
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8547a63535..6b347b1749 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -23,6 +23,9 @@ MAX_DEPTH = 2**63 - 1
# the maximum length for a room alias is 255 characters
MAX_ALIAS_LENGTH = 255
+# the maximum length for a user id is 255 characters
+MAX_USERID_LENGTH = 255
+
class Membership(object):
@@ -116,3 +119,11 @@ class UserTypes(object):
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
+
+
+class RelationTypes(object):
+ """The types of relations known to this server.
+ """
+ ANNOTATION = "m.annotation"
+ REPLACE = "m.replace"
+ REFERENCE = "m.reference"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index e77abe1040..485b3d0237 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -19,13 +19,15 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
- V1 = 1 # $id:server format
- V2 = 2 # MSC1659-style $hash format: introduced for room v3
+ V1 = 1 # $id:server event id format
+ V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
+    V3 = 3 # MSC1884-style $hash event id format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1,
EventFormatVersions.V2,
+ EventFormatVersions.V3,
}
@@ -75,6 +77,12 @@ class RoomVersions(object):
EventFormatVersions.V2,
StateResolutionVersions.V2,
)
+ EVENTID_NOSLASH_TEST = RoomVersion(
+ "eventid-noslash-test",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ )
# the version we will give rooms which are created on this server
@@ -87,5 +95,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
+ RoomVersions.EVENTID_NOSLASH_TEST,
)
} # type: dict[str, RoomVersion]
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7874cd9da7..f34aa42afa 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -100,6 +101,11 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
+ # Whether to enable experimental MSC1849 (aka relations) support
+ self.experimental_msc1849_support_enabled = config.get(
+ "experimental_msc1849_support_enabled", False,
+ )
+
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
@@ -173,6 +179,10 @@ class ServerConfig(Config):
"require_membership_for_aliases", True,
)
+        # Whether to allow per-room membership profiles, i.e. membership events sent
+        # with profile information that differs from the target's global profile.
+ self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -566,6 +576,12 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#require_membership_for_aliases: false
+
+        # Whether to allow per-room membership profiles, i.e. membership events sent
+        # with profile information that differs from the target's global profile.
+ # Defaults to 'true'.
+ #
+ #allow_per_room_profiles: false
""" % locals()
def read_arguments(self, args):
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 12056d5be2..badeb903fc 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -335,13 +335,32 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
- return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
+ return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+ self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
+class FrozenEventV3(FrozenEventV2):
+ """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+ format_version = EventFormatVersions.V3 # All events of this type are V3
+
+ @property
+ def event_id(self):
+ # We have to import this here as otherwise we get an import loop which
+ # is hard to break.
+ from synapse.crypto.event_signing import compute_event_reference_hash
+
+ if self._event_id:
+ return self._event_id
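+        # MSC1884: use the URL-safe base64 alphabet so that the event ID can
+        # never contain a '/'.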
+ self._event_id = "$" + encode_base64(
+ compute_event_reference_hash(self)[1], urlsafe=True
+ )
+ return self._event_id
+
+
def room_version_to_event_format(room_version):
"""Converts a room version string to the event format
@@ -376,6 +395,8 @@ def event_type_from_format_version(format_version):
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
+ elif format_version == EventFormatVersions.V3:
+ return FrozenEventV3
else:
raise Exception(
"No event format %r" % (format_version,)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a5454556cc..27a2a9ef98 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -21,7 +21,7 @@ from frozendict import frozendict
from twisted.internet import defer
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.util.async_helpers import yieldable_gather_results
from . import EventBase
@@ -324,8 +324,12 @@ class EventClientSerializer(object):
"""
def __init__(self, hs):
- pass
+ self.store = hs.get_datastore()
+ self.experimental_msc1849_support_enabled = (
+ hs.config.experimental_msc1849_support_enabled
+ )
+ @defer.inlineCallbacks
def serialize_event(self, event, time_now, **kwargs):
"""Serializes a single event.
@@ -337,8 +341,52 @@ class EventClientSerializer(object):
Returns:
Deferred[dict]: The serialized event
"""
- event = serialize_event(event, time_now, **kwargs)
- return defer.succeed(event)
+ # To handle the case of presence events and the like
+ if not isinstance(event, EventBase):
+ defer.returnValue(event)
+
+ event_id = event.event_id
+ serialized_event = serialize_event(event, time_now, **kwargs)
+
+        # If MSC1849 is enabled then we need to check if there are any
+        # relations we need to bundle in with the event.
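+        # The bundled data ends up under unsigned["m.relations"]; an
+        # illustrative shape:
+        #   {
+        #       "m.annotation": {"chunk": [{"type": "m.reaction", "key": "👍", "count": 3}]},
+        #       "m.reference": {"chunk": [{"event_id": "$other_event"}]},
+        #   }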
+ if self.experimental_msc1849_support_enabled:
+ annotations = yield self.store.get_aggregation_groups_for_event(
+ event_id,
+ )
+ references = yield self.store.get_relations_for_event(
+ event_id, RelationTypes.REFERENCE, direction="f",
+ )
+
+ if annotations.chunk:
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ if references.chunk:
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = yield self.store.get_applicable_edit(event_id)
+
+ if edit:
+ # If there is an edit replace the content, preserving existing
+ # relations.
+
+ relations = event.content.get("m.relates_to")
+ serialized_event["content"] = edit.content.get("m.new_content", {})
+ if relations:
+ serialized_event["content"]["m.relates_to"] = relations
+ else:
+ serialized_event["content"].pop("m.relates_to", None)
+
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REPLACE] = {
+ "event_id": edit.event_id,
+ }
+
+ defer.returnValue(serialized_event)
def serialize_events(self, events, time_now, **kwargs):
"""Serializes multiple events.
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a51d11a257..e83ee24f10 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -19,7 +19,7 @@ import logging
from twisted.internet import defer
from synapse import types
-from synapse.api.constants import LoginType
+from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import (
AuthError,
Codes,
@@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id)
+ if len(user_id) > MAX_USERID_LENGTH:
+ raise SynapseError(
+ 400,
+ "User ID may not be longer than %s characters" % (
+ MAX_USERID_LENGTH,
+ ),
+ Codes.INVALID_USERNAME
+ )
+
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 3e86b9c690..93ac986c86 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -73,6 +74,7 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self._server_notices_mxid = self.config.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
+ self.allow_per_room_profiles = self.config.allow_per_room_profiles
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
@@ -357,6 +359,13 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
+ if not self.allow_per_room_profiles:
+ # Strip profile data, knowing that new profile data will be added to the
+ # event's content in event_creation_handler.create_event() using the target's
+ # global profile.
+ content.pop("displayname", None)
+ content.pop("avatar_url", None)
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -935,7 +944,7 @@ class RoomMemberHandler(object):
}
if self.config.invite_3pid_guest:
- guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
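+            # get_or_register_3pid_guest returns (user_id, access_token); the
+            # two values were previously unpacked the wrong way round.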
+ guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
requester=requester,
medium=medium,
address=address,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 153312e39f..1ee9a6e313 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -934,7 +934,7 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -943,7 +943,7 @@ class SyncHandler(object):
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_users
+ sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder)
@@ -951,7 +951,7 @@ class SyncHandler(object):
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_users=newly_joined_users,
+ newly_joined_or_invited_users=newly_joined_or_invited_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1036,7 +1036,8 @@ class SyncHandler(object):
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder,
- newly_joined_rooms, newly_joined_users,
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1050,7 +1051,7 @@ class SyncHandler(object):
# share a room with?
for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_users_in_room(room_id)
- newly_joined_users.update(joined_users)
+ newly_joined_or_invited_users.update(joined_users)
for room_id in newly_left_rooms:
left_users = yield self.state.get_current_users_in_room(room_id)
@@ -1058,7 +1059,7 @@ class SyncHandler(object):
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- changed.update(newly_joined_users)
+ changed.update(newly_joined_or_invited_users)
if not changed and not newly_left_users:
defer.returnValue(DeviceLists(
@@ -1176,7 +1177,7 @@ class SyncHandler(object):
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
- newly_joined_users):
+ newly_joined_or_invited_users):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1184,8 +1185,9 @@ class SyncHandler(object):
sync_result_builder(SyncResultBuilder)
newly_joined_rooms(list): List of rooms that the user has joined
since the last sync (or empty if an initial sync)
- newly_joined_users(list): List of users that have joined rooms
- since the last sync (or empty if an initial sync)
+ newly_joined_or_invited_users(list): List of users that have joined
+ or been invited to rooms since the last sync (or empty if an initial
+ sync)
"""
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@@ -1211,7 +1213,7 @@ class SyncHandler(object):
"presence_key", presence_key
)
- extra_users_ids = set(newly_joined_users)
+ extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
@@ -1243,7 +1245,8 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a 4-tuple of
- `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
+ `(newly_joined_rooms, newly_joined_or_invited_users,
+ newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@@ -1314,8 +1317,8 @@ class SyncHandler(object):
sync_result_builder.invited.extend(invited)
- # Now we want to get any newly joined users
- newly_joined_users = set()
+ # Now we want to get any newly joined or invited users
+ newly_joined_or_invited_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1324,19 +1327,22 @@ class SyncHandler(object):
)
for event in it:
if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- newly_joined_users.add(event.state_key)
+ if (
+ event.membership == Membership.JOIN or
+ event.membership == Membership.INVITE
+ ):
+ newly_joined_or_invited_users.add(event.state_key)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_users
+ newly_left_users -= newly_joined_or_invited_users
defer.returnValue((
newly_joined_rooms,
- newly_joined_users,
+ newly_joined_or_invited_users,
newly_left_rooms,
newly_left_users,
))
@@ -1381,7 +1387,7 @@ class SyncHandler(object):
where:
room_entries is a list [RoomSyncResultBuilder]
invited_rooms is a list [InvitedSyncResult]
- newly_joined rooms is a list[str] of room ids
+ newly_joined_rooms is a list[str] of room ids
newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1422,7 +1428,7 @@ class SyncHandler(object):
if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly.
- # If there are non join member events, but we are still in the room,
+ # If there are non-join member events, but we are still in the room,
# then the user must have left and joined
newly_joined_rooms.append(room_id)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2708f5e820..fdcfb90a7e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -53,7 +53,7 @@ REQUIREMENTS = [
"pyasn1-modules>=0.0.7",
"daemonize>=2.3.1",
"bcrypt>=3.1.0",
- "pillow>=3.1.2",
+ "pillow>=4.3.0",
"sortedcontainers>=1.4.4",
"psutil>=2.0.0",
"pymacaroons>=0.13.0",
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b457c5563f..a3952506c1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.relations import RelationsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage.state import StateGroupWorkerStore
@@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore,
EventsWorkerStore,
SignatureWorkerStore,
UserErasureWorkerStore,
+ RelationsWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs):
@@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore,
for row in rows:
self.invalidate_caches_for_event(
-token, row.event_id, row.room_id, row.type, row.state_key,
- row.redacts,
+ row.redacts, row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
@@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore,
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
token, data.event_id, data.room_id, data.type, data.state_key,
- data.redacts,
+ data.redacts, data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
@@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore,
raise Exception("Unknown events stream row type %s" % (row.type, ))
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
- etype, state_key, redacts, backfilled):
+ etype, state_key, redacts, relates_to,
+ backfilled):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore,
state_key, stream_ordering
)
self.get_invited_rooms_for_user.invalidate((state_key,))
+
+ if relates_to:
+ self.get_relations_for_event.invalidate_many((relates_to,))
+ self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8971a6a22e..b6ce7a7bee 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", (
"type", # str
"state_key", # str, optional
"redacts", # str, optional
+ "relates_to", # str, optional
))
PresenceStreamRow = namedtuple("PresenceStreamRow", (
"user_id", # str
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e0f6e29248..f1290d022a 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -80,11 +80,12 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
+ relates_to = attr.ib() # str, optional
@attr.s(slots=True, frozen=True)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3a24d31d1b..e6110ad9b1 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import (
read_marker,
receipts,
register,
+ relations,
report_event,
room_keys,
room_upgrade_rest_servlet,
@@ -115,6 +116,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ relations.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 3d045880b9..042f636135 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -348,18 +348,22 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.enable_registration_captcha:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA]])
+ # Also add a dummy flow here, otherwise if a client completes
+ # recaptcha first we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+ flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
- flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+ [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
else:
# only support 3PIDless registration if no 3PIDs are required
@@ -382,7 +386,15 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.user_consent_at_registration:
new_flows = []
for flow in flows:
- flow.append(LoginType.TERMS)
+ inserted = False
+ # m.login.terms should go near the end but before msisdn or email auth
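+                # (e.g. [recaptcha, email] becomes [recaptcha, terms, email],
+                # while [recaptcha, dummy] becomes [recaptcha, dummy, terms])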
+ for i, stage in enumerate(flow):
+ if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
+ flow.insert(i, LoginType.TERMS)
+ inserted = True
+ break
+ if not inserted:
+ flow.append(LoginType.TERMS)
flows.extend(new_flows)
auth_result, params, session_id = yield self.auth_handler.check_auth(
@@ -394,13 +406,6 @@ class RegisterRestServlet(RestServlet):
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
- #
- # Also check that we're not trying to register a 3pid that's already
- # been registered.
- #
- # This has probably happened in /register/email/requestToken as well,
- # but if a user hits this endpoint twice then clicks on each link from
- # the two activation emails, they would register the same 3pid twice.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
@@ -416,17 +421,6 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existingUid = yield self.store.get_user_id_by_threepid(
- medium, address,
- )
-
- if existingUid is not None:
- raise SynapseError(
- 400,
- "%s is already in use" % medium,
- Codes.THREEPID_IN_USE,
- )
-
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
@@ -449,6 +443,28 @@ class RegisterRestServlet(RestServlet):
if auth_result:
threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
+ # Also check that we're not trying to register a 3pid that's already
+ # been registered.
+ #
+ # This has probably happened in /register/email/requestToken as well,
+ # but if a user hits this endpoint twice then clicks on each link from
+ # the two activation emails, they would register the same 3pid twice.
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ existingUid = yield self.store.get_user_id_by_threepid(
+ medium, address,
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400,
+ "%s is already in use" % medium,
+ Codes.THREEPID_IN_USE,
+ )
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
new file mode 100644
index 0000000000..41e0a44936
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -0,0 +1,338 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This class implements the proposed relation APIs from MSC 1849.
+
+Since the MSC has not been approved all APIs here are unstable and may change at
+any time to reflect changes in the MSC.
+"""
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class RelationSendServlet(RestServlet):
+ """Helper API for sending events that have relation data.
+
+ Example API shape to send a 👍 reaction to a room:
+
+ POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
+ {}
+
+ {
+ "event_id": "$foobar"
+ }
+ """
+
+ PATTERN = (
+ "/rooms/(?P<room_id>[^/]*)/send_relation"
+ "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ super(RelationSendServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.txns = HttpTransactionCache(hs)
+
+ def register(self, http_server):
+ http_server.register_paths(
+ "POST",
+ client_v2_patterns(self.PATTERN + "$", releases=()),
+ self.on_PUT_or_POST,
+ )
+ http_server.register_paths(
+ "PUT",
+ client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
+ self.on_PUT,
+ )
+
+ def on_PUT(self, request, *args, **kwargs):
+ return self.txns.fetch_or_execute_request(
+ request, self.on_PUT_or_POST, request, *args, **kwargs
+ )
+
+ @defer.inlineCallbacks
+ def on_PUT_or_POST(
+ self, request, room_id, parent_id, relation_type, event_type, txn_id=None
+ ):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ if event_type == EventTypes.Member:
+            # Adding relations to membership events is meaningless, so we just
+            # deny it at the CS API rather than trying to handle it correctly.
+ raise SynapseError(400, "Cannot send member events with relations")
+
+ content = parse_json_object_from_request(request)
+
+ aggregation_key = parse_string(request, "key", encoding="utf-8")
+
+ content["m.relates_to"] = {
+ "event_id": parent_id,
+ "key": aggregation_key,
+ "rel_type": relation_type,
+ }
+
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict=event_dict, txn_id=txn_id
+ )
+
+ defer.returnValue((200, {"event_id": event.event_id}))
+
+
+class RelationPaginationServlet(RestServlet):
+ """API to paginate relations on an event by topological ordering, optionally
+ filtered by relation type and event type.
+ """
+
+ PATTERNS = client_v2_patterns(
+ "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
+ "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ result = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = yield self.store.get_events_as_list(
+ [c["event_id"] for c in result.chunk]
+ )
+
+ now = self.clock.time_msec()
+ events = yield self._event_serializer.serialize_events(events, now)
+
+ return_value = result.to_dict()
+ return_value["chunk"] = events
+
+ defer.returnValue((200, return_value))
+
+
+class RelationAggregationPaginationServlet(RestServlet):
+ """API to paginate aggregation groups of relations, e.g. paginate the
+ types and counts of the reactions on the events.
+
+ Example request and response:
+
+ GET /rooms/{room_id}/aggregations/{parent_id}
+
+ {
+ chunk: [
+ {
+ "type": "m.reaction",
+ "key": "👍",
+ "count": 3
+ }
+ ]
+ }
+ """
+
+ PATTERNS = client_v2_patterns(
+ "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+ "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationAggregationPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ if relation_type not in (RelationTypes.ANNOTATION, None):
+ raise SynapseError(400, "Relation type must be 'annotation'")
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = AggregationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = AggregationPaginationToken.from_string(to_token)
+
+ res = yield self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ defer.returnValue((200, res.to_dict()))
+
+
+class RelationAggregationGroupPaginationServlet(RestServlet):
+ """API to paginate within an aggregation group of relations, e.g. paginate
+ all the 👍 reactions on an event.
+
+ Example request and response:
+
+ GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
+
+ {
+ chunk: [
+ {
+ "type": "m.reaction",
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.annotation",
+ "key": "👍"
+ }
+ }
+ },
+ ...
+ ]
+ }
+ """
+
+ PATTERNS = client_v2_patterns(
+ "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+ "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationAggregationGroupPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ if relation_type != RelationTypes.ANNOTATION:
+ raise SynapseError(400, "Relation type must be 'annotation'")
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ result = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=key,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = yield self.store.get_events_as_list(
+ [c["event_id"] for c in result.chunk]
+ )
+
+ now = self.clock.time_msec()
+ events = yield self._event_serializer.serialize_events(events, now)
+
+ return_value = result.to_dict()
+ return_value["chunk"] = events
+
+ defer.returnValue((200, return_value))
+
+
+def register_servlets(hs, http_server):
+ RelationSendServlet(hs).register(http_server)
+ RelationPaginationServlet(hs).register(http_server)
+ RelationAggregationPaginationServlet(hs).register(http_server)
+ RelationAggregationGroupPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index bdffa97805..8569677355 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -444,6 +444,9 @@ class MediaRepository(object):
)
return
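+        # If the image has an EXIF orientation tag, apply the corresponding
+        # transposition before thumbnailing so the result is the right way up.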
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = thumbnailer.transpose()
+
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
@@ -578,6 +581,12 @@ class MediaRepository(object):
)
return
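+        # As above, but the transposition is punted to a thread pool so that
+        # the (potentially expensive) image operation doesn't block the
+        # reactor.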
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
+ thumbnailer.transpose
+ )
+
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a4b26c2587..3efd0d80fc 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -20,6 +20,17 @@ import PIL.Image as Image
logger = logging.getLogger(__name__)
+EXIF_ORIENTATION_TAG = 0x0112
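+
+# Maps EXIF orientation values (2-8) to the PIL transpose operation that
+# undoes them; orientation 1 means "upright" and needs no transposition.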
+EXIF_TRANSPOSE_MAPPINGS = {
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90
+}
+
class Thumbnailer(object):
@@ -31,6 +42,30 @@ class Thumbnailer(object):
def __init__(self, input_path):
self.image = Image.open(input_path)
self.width, self.height = self.image.size
+ self.transpose_method = None
+ try:
+            # We don't use ImageOps.exif_transpose since it crashes on large EXIF data
+ image_exif = self.image._getexif()
+ if image_exif is not None:
+ image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
+ self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
+ except Exception as e:
+ # A lot of parsing errors can happen when parsing EXIF
+ logger.info("Error parsing image EXIF information: %s", e)
+
+ def transpose(self):
+ """Transpose the image using its EXIF Orientation tag
+
+ Returns:
+ Tuple[int, int]: (width, height) containing the new image size in pixels.
+ """
+ if self.transpose_method is not None:
+ self.image = self.image.transpose(self.transpose_method)
+ self.width, self.height = self.image.size
+ self.transpose_method = None
+ # We don't need EXIF any more
+ self.image.info["exif"] = None
+ return self.image.size
def aspect(self, max_width, max_height):
"""Calculate the largest size that preserves aspect ratio which
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c432041b4e..7522d3fd57 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
+from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
@@ -99,6 +100,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
+ RelationsStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a7f841c6c..881d6d0126 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1325,6 +1325,9 @@ class EventsStore(
txn, event.room_id, event.redacts
)
+ # Remove from relations table.
+ self._handle_redaction(txn, event.redacts)
+
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -1351,6 +1354,8 @@ class EventsStore(
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
+ self._handle_event_relations(txn, event)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1655,10 +1660,11 @@ class EventsStore(
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1673,11 +1679,12 @@ class EventsStore(
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
@@ -1698,10 +1705,11 @@ class EventsStore(
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1716,11 +1724,12 @@ class EventsStore(
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
new file mode 100644
index 0000000000..493abe405e
--- /dev/null
+++ b/synapse/storage/relations.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.stream import generate_pagination_where_clause
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s
+class PaginationChunk(object):
+ """Returned by relation pagination APIs.
+
+ Attributes:
+ chunk (list): The rows returned by pagination
+ next_batch (Any|None): Token to fetch next set of results with, if
+ None then there are no more results.
+ prev_batch (Any|None): Token to fetch previous set of results with, if
+ None then there are no previous results.
+ """
+
+ chunk = attr.ib()
+ next_batch = attr.ib(default=None)
+ prev_batch = attr.ib(default=None)
+
+ def to_dict(self):
+ d = {"chunk": self.chunk}
+
+ if self.next_batch:
+ d["next_batch"] = self.next_batch.to_string()
+
+ if self.prev_batch:
+ d["prev_batch"] = self.prev_batch.to_string()
+
+ return d
+
+
+@attr.s(frozen=True, slots=True)
+class RelationPaginationToken(object):
+ """Pagination token for relation pagination API.
+
+    As the results are ordered by topological ordering, we can use the
+ `topological_ordering` and `stream_ordering` fields of the events at the
+ boundaries of the chunk as pagination tokens.
+
+ Attributes:
+ topological (int): The topological ordering of the boundary event
+ stream (int): The stream ordering of the boundary event.
+ """
+
+ topological = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ t, s = string.split("-")
+ return RelationPaginationToken(int(t), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.topological, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+@attr.s(frozen=True, slots=True)
+class AggregationPaginationToken(object):
+ """Pagination token for relation aggregation pagination API.
+
+    As the results are ordered by count and then MAX(stream_ordering) of the
+ aggregation groups, we can just use them as our pagination token.
+
+ Attributes:
+        count (int): The count of relations in the boundary group.
+ stream (int): The MAX stream ordering in the boundary group.
+ """
+
+ count = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ c, s = string.split("-")
+ return AggregationPaginationToken(int(c), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.count, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+ @cached(tree=True)
+ def get_relations_for_event(
+ self,
+ event_id,
+ relation_type=None,
+ event_type=None,
+ aggregation_key=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ aggregation_key (str|None): Only fetch events with this aggregation
+ key, if given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+ from_token (RelationPaginationToken|None): Fetch rows from the given
+ token, or from the start if None.
+ to_token (RelationPaginationToken|None): Fetch rows up to the given
+ token, or up to the end if None.
+
+ Returns:
+ Deferred[PaginationChunk]: List of event IDs that match relations
+ requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type is not None:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type is not None:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ if aggregation_key:
+ where_clause.append("aggregation_key = ?")
+ where_args.append(aggregation_key)
+
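+        # Build the SQL fragment bounding (topological_ordering,
+        # stream_ordering) between the from/to tokens, using the generic
+        # pagination helper added alongside this change.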
+ pagination_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ where_clause.append(pagination_clause)
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
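+            # Fetch one more row than requested so that we can tell whether
+            # there is a further page of results.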
+ txn.execute(sql, where_args + [limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ events = []
+ for row in txn:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[1]
+ last_stream_id = row[2]
+
+ next_batch = None
+ if len(events) > limit and last_topo_id and last_stream_id:
+ next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
+ @cached(tree=True)
+ def get_aggregation_groups_for_event(
+ self,
+ event_id,
+ event_type=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+    This is used e.g. to get what reactions have happened on an event and
+    how many of each.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the `limit` groups.
+ direction (str): Whether to fetch the highest count first (`"b"`) or
+ the lowest count first (`"f"`).
+ from_token (AggregationPaginationToken|None): Fetch rows from the
+ given token, or from the start if None.
+ to_token (AggregationPaginationToken|None): Fetch rows up to the
+ given token, or up to the end if None.
+
+
+ Returns:
+ Deferred[PaginationChunk]: List of groups of annotations that
+ match. Each row is a dict with `type`, `key` and `count` fields.
+ """
+
+ where_clause = ["relates_to_id = ?", "relation_type = ?"]
+ where_args = [event_id, RelationTypes.ANNOTATION]
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ having_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("COUNT(*)", "MAX(stream_ordering)"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ if having_clause:
+ having_clause = "HAVING " + having_clause
+ else:
+ having_clause = ""
+
+ sql = """
+ SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE {where_clause}
+ GROUP BY relation_type, type, aggregation_key
+ {having_clause}
+ ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ LIMIT ?
+ """.format(
+ where_clause=" AND ".join(where_clause),
+ order=order,
+ having_clause=having_clause,
+ )
+
+ def _get_aggregation_groups_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ next_batch = None
+ events = []
+ for row in txn:
+ events.append({"type": row[0], "key": row[1], "count": row[2]})
+ next_batch = AggregationPaginationToken(row[2], row[3])
+
+ if len(events) <= limit:
+ next_batch = None
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ )
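Because the pagination bound here compares aggregates (`COUNT(*)` and `MAX(stream_ordering)`), it has to live in the `HAVING` clause rather than `WHERE`. As a rough, hedged illustration (not part of the patch), the formatted query for `direction="b"` with a `from_token` of `(3, 42)` and no event type filter would come out along these lines on SQLite:

```python
# Hypothetical rendering of the aggregation query for direction="b" and
# from_token=(3, 42), using the SQLite expansion from
# _make_generic_sql_bound further down in this patch.
sql = """
    SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
    FROM event_relations
    INNER JOIN events USING (event_id)
    WHERE relates_to_id = ? AND relation_type = ?
    GROUP BY relation_type, type, aggregation_key
    HAVING (3 > COUNT(*) OR (3 = COUNT(*) AND 42 >= MAX(stream_ordering)))
    ORDER BY COUNT(*) DESC, MAX(stream_ordering) DESC
    LIMIT ?
"""
```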
+
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+        # Fetches the latest edit that has the same type and sender as the
+        # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ defer.returnValue(edit_event)
+
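The edit found here is what eventually gets bundled into the parent event's `unsigned` section when it is serialised for clients; the tests added later in this patch assert a shape along these lines (the event ID is illustrative):

```python
# Roughly what a client sees for an edited event once the applicable edit
# has been bundled (compare test_edit below); the event ID is made up.
parent_event = {
    "type": "m.room.message",
    "content": {"msgtype": "m.text", "body": "I've been edited!"},
    "unsigned": {
        "m.relations": {
            "m.replace": {"event_id": "$edit_event_id"},
        },
    },
}
```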
+
+class RelationsStore(RelationsWorkerStore):
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
+
+ Args:
+            txn: The database transaction to use.
+            event (EventBase): The event being persisted, which may contain
+                relation data.
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
+
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
+
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
+
+ aggregation_key = relation.get("key")
+
+ self._simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
+ )
+
+ txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+ )
+
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
+
+ Args:
+            txn: The database transaction to use.
+ redacted_event_id (str): The event that was redacted.
+ """
+
+ self._simple_delete_txn(
+ txn,
+ table="event_relations",
+            keyvalues={"event_id": redacted_event_id},
+        )
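For reference, the relation data that `_handle_event_relations` extracts above lives under `m.relates_to` in the event content. The tests added below send events of roughly this form (IDs are illustrative):

```python
# Event content carrying a relation, as parsed by _handle_event_relations
# (mirrors what test_send_relation below produces).
content = {
    "m.relates_to": {
        "rel_type": "m.annotation",      # RelationTypes.ANNOTATION
        "event_id": "$parent_event_id",  # the event being related to
        "key": "👍",                     # aggregation key (annotations only)
    }
}
```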
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql
new file mode 100644
index 0000000000..134862b870
--- /dev/null
+++ b/synapse/storage/schema/delta/54/relations.sql
@@ -0,0 +1,27 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks related events, like reactions, replies, edits, etc. Note that things
+-- in this table are not necessarily "valid", e.g. it may contain edits from
+-- people who don't have the power to edit other people's events.
+CREATE TABLE IF NOT EXISTS event_relations (
+ event_id TEXT NOT NULL,
+ relates_to_id TEXT NOT NULL,
+ relation_type TEXT NOT NULL,
+ aggregation_key TEXT
+);
+
+CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
+CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);
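As a rough illustration (not part of the patch) of what the `event_relations_relates` index serves, the annotation queries above boil down to lookups like:

```python
# Hypothetical lookup served by the event_relations_relates index:
# count the reactions on one event, grouped by aggregation key.
sql = """
    SELECT aggregation_key, COUNT(*)
    FROM event_relations
    WHERE relates_to_id = ? AND relation_type = 'm.annotation'
    GROUP BY aggregation_key
"""
```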
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d105b6b17d..529ad4ea79 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -64,59 +64,135 @@ _EventDictReturn = namedtuple(
)
-def lower_bound(token, engine, inclusive=False):
- inclusive = "=" if inclusive else ""
- if token.topological is None:
- return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
- else:
- if isinstance(engine, PostgresEngine):
- # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
- # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
- # use the later form when running against postgres.
- return "((%d,%d) <%s (%s,%s))" % (
- token.topological,
- token.stream,
- inclusive,
- "topological_ordering",
- "stream_ordering",
+def generate_pagination_where_clause(
+ direction, column_names, from_token, to_token, engine,
+):
+ """Creates an SQL expression to bound the columns by the pagination
+ tokens.
+
+    For example, it creates an SQL expression like:
+
+ (6, 7) >= (topological_ordering, stream_ordering)
+ AND (5, 3) < (topological_ordering, stream_ordering)
+
+    would be generated for direction="b", from_token=(6, 7) and to_token=(5, 3).
+
+ Note that tokens are considered to be after the row they are in, e.g. if
+ a row A has a token T, then we consider A to be before T. This convention
+ is important when figuring out inequalities for the generated SQL, and
+ produces the following result:
+ - If paginating forwards then we exclude any rows matching the from
+ token, but include those that match the to token.
+    - If paginating backwards then we include any rows matching the from
+      token, but exclude those that match the to token.
+
+ Args:
+        direction (str): Whether we're paginating backwards ("b") or
+ forwards ("f").
+ column_names (tuple[str, str]): The column names to bound. Must *not*
+ be user defined as these get inserted directly into the SQL
+ statement without escapes.
+ from_token (tuple[int, int]|None): The start point for the pagination.
+ This is an exclusive minimum bound if direction is "f", and an
+ inclusive maximum bound if direction is "b".
+        to_token (tuple[int, int]|None): The end point for the pagination.
+ This is an inclusive maximum bound if direction is "f", and an
+ exclusive minimum bound if direction is "b".
+ engine: The database engine to generate the clauses for
+
+ Returns:
+ str: The sql expression
+ """
+ assert direction in ("b", "f")
+
+ where_clause = []
+ if from_token:
+ where_clause.append(
+ _make_generic_sql_bound(
+ bound=">=" if direction == "b" else "<",
+ column_names=column_names,
+ values=from_token,
+ engine=engine,
)
- return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
- token.topological,
- "topological_ordering",
- token.topological,
- "topological_ordering",
- token.stream,
- inclusive,
- "stream_ordering",
- )
-
-
-def upper_bound(token, engine, inclusive=True):
- inclusive = "=" if inclusive else ""
- if token.topological is None:
- return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
- else:
- if isinstance(engine, PostgresEngine):
- # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
- # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
- # use the later form when running against postgres.
- return "((%d,%d) >%s (%s,%s))" % (
- token.topological,
- token.stream,
- inclusive,
- "topological_ordering",
- "stream_ordering",
+ )
+
+ if to_token:
+ where_clause.append(
+ _make_generic_sql_bound(
+ bound="<" if direction == "b" else ">=",
+ column_names=column_names,
+ values=to_token,
+ engine=engine,
)
- return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
- token.topological,
- "topological_ordering",
- token.topological,
- "topological_ordering",
- token.stream,
- inclusive,
- "stream_ordering",
)
+ return " AND ".join(where_clause)
+
+
+def _make_generic_sql_bound(bound, column_names, values, engine):
+ """Create an SQL expression that bounds the given column names by the
+ values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
+
+ Only works with two columns.
+
+ Older versions of SQLite don't support that syntax so we have to expand it
+ out manually.
+
+ Args:
+ bound (str): The comparison operator to use. One of ">", "<", ">=",
+ "<=", where the values are on the left and columns on the right.
+        column_names (tuple[str, str]): The column names. Must *not* be user
+            defined as these get inserted directly into the SQL statement
+            without escapes.
+ values (tuple[int|None, int]): The values to bound the columns by. If
+ the first value is None then only creates a bound on the second
+ column.
+ engine: The database engine to generate the SQL for
+
+ Returns:
+ str
+ """
+
+    assert bound in (">", "<", ">=", "<=")
+
+ name1, name2 = column_names
+ val1, val2 = values
+
+ if val1 is None:
+ val2 = int(val2)
+ return "(%d %s %s)" % (val2, bound, name2)
+
+ val1 = int(val1)
+ val2 = int(val2)
+
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+        # use the latter form when running against postgres.
+ return "((%d,%d) %s (%s,%s))" % (
+ val1, val2,
+ bound,
+ name1, name2,
+ )
+
+ # We want to generate queries of e.g. the form:
+ #
+ # (val1 < name1 OR (val1 = name1 AND val2 <= name2))
+ #
+    # which is equivalent to (val1, val2) <= (name1, name2)
+
+ return """(
+ {val1:d} {strict_bound} {name1}
+ OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
+ )""".format(
+ name1=name1,
+ val1=val1,
+ name2=name2,
+ val2=val2,
+        strict_bound=bound[0],  # the first comparison must always be strict
+ bound=bound,
+ )
+
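To make the two engine branches concrete, here is a sketch of the strings `_make_generic_sql_bound` should produce for one set of inputs (output reformatted slightly for readability):

```python
# Expected clauses for bound="<",
# column_names=("topological_ordering", "stream_ordering"), values=(5, 3):
#
# Postgres (multicolumn comparison, optimises well on the index):
#   ((5,3) < (topological_ordering,stream_ordering))
#
# SQLite (expanded form, equivalent row-wise comparison):
#   (5 < topological_ordering
#    OR (5 = topological_ordering AND 3 < stream_ordering))
```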
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
@@ -762,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(from_token, self.database_engine)
- if to_token:
- bounds = "%s AND %s" % (
- bounds,
- lower_bound(to_token, self.database_engine),
- )
else:
order = "ASC"
- bounds = lower_bound(from_token, self.database_engine)
- if to_token:
- bounds = "%s AND %s" % (
- bounds,
- upper_bound(to_token, self.database_engine),
- )
+
+ bounds = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=from_token,
+ to_token=to_token,
+ engine=self.database_engine,
+ )
filter_clause, filter_args = filter_to_clause(event_filter)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1c253d0579..5ffba2ca7a 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -228,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_register_not_support_user(self):
res = self.get_success(self.handler.register(localpart='user'))
self.assertFalse(self.store.is_support_user(res[0]))
+
+ def test_invalid_user_id_length(self):
+ invalid_user_id = "x" * 256
+ self.get_failure(
+ self.handler.register(localpart=invalid_user_id),
+ SynapseError
+ )
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 6220172cde..5f75ad7579 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +25,7 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import Membership
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v1 import login, profile, room
from tests import unittest
@@ -936,3 +937,70 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request("GET", self.url, access_token=tok)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
+
+
+class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["allow_per_room_profiles"] = False
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("test", "test")
+ self.tok = self.login("test", "test")
+
+ # Set a profile for the test user
+ self.displayname = "test user"
+ data = {
+ "displayname": self.displayname,
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
+ request_data,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_per_room_profile_forbidden(self):
+ data = {
+ "membership": "join",
+ "displayname": "other test user"
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ self.room_id, self.user_id,
+ ),
+ request_data,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res_displayname = channel.json_body["content"]["displayname"]
+ self.assertEqual(res_displayname, self.displayname, channel.result)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index ad7d476401..b9ef46e8fb 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -92,7 +92,14 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(self.recaptcha_attempts), 1)
self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
- # Now we have fufilled the recaptcha fallback step, we can then send a
+ # also complete the dummy auth
+ request, channel = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ )
+ self.render(request)
+
+    # Now we should have fulfilled a complete auth flow, including
+    # the recaptcha fallback step, so we can then send a
# request to the register API with the session in the authdict.
request, channel = self.make_request(
"POST", "register", {"auth": {"session": session}}
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
new file mode 100644
index 0000000000..3d040cf118
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -0,0 +1,539 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import json
+
+import six
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import register, relations
+
+from tests import unittest
+
+
+class RelationsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ relations.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ ]
+ hijack_auth = False
+
+ def make_homeserver(self, reactor, clock):
+ # We need to enable msc1849 support for aggregations
+ config = self.default_config()
+ config["experimental_msc1849_support_enabled"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id, self.user_token = self._create_user("alice")
+ self.user2_id, self.user2_token = self._create_user("bob")
+
+ self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
+ res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
+ self.parent_id = res["event_id"]
+
+ def test_send_relation(self):
+ """Tests that sending a relation using the new /send_relation works
+ creates the right shape of event.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assert_dict(
+ {
+ "type": "m.reaction",
+ "sender": self.user_id,
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": u"👍",
+ "rel_type": RelationTypes.ANNOTATION,
+ }
+ },
+ },
+ channel.json_body,
+ )
+
+ def test_deny_membership(self):
+ """Test that we deny relations on membership events
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_basic_paginate_relations(self):
+ """Tests that calling pagination API corectly the latest relations.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
+ channel.json_body["chunk"][0],
+ )
+
+ # Make sure next_batch has something in it that looks like it could be a
+ # valid token.
+ self.assertIsInstance(
+ channel.json_body.get("next_batch"), six.string_types, channel.json_body
+ )
+
+ def test_repeated_paginate_relations(self):
+ """Test that if we paginate using a limit and tokens then we get the
+ expected events.
+ """
+
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+ self.assertEquals(200, channel.code, channel.json_body)
+ expected_event_ids.append(channel.json_body["event_id"])
+
+ prev_token = None
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
+ % (self.room, self.parent_id, from_token),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_aggregation_pagination_groups(self):
+ """Test that we can paginate annotation groups correctly.
+ """
+
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
+ sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1}
+ for key in itertools.chain.from_iterable(
+ itertools.repeat(key, num) for key, num in sent_groups.items()
+ ):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=key,
+ access_token=access_tokens[idx],
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ idx += 1
+ idx %= len(access_tokens)
+
+ prev_token = None
+ found_groups = {}
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
+ % (self.room, self.parent_id, from_token),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ for groups in channel.json_body["chunk"]:
+ # We only expect reactions
+ self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+ # We should only see each key once
+ self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+ found_groups[groups["key"]] = groups["count"]
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ self.assertEquals(sent_groups, found_groups)
+
+ def test_aggregation_pagination_within_group(self):
+ """Test that we can paginate within an annotation group.
+ """
+
+ expected_event_ids = []
+ for _ in range(10):
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", key=u"👍"
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ expected_event_ids.append(channel.json_body["event_id"])
+
+        # Also send a reaction with a different key so that we can check
+        # it isn't included below.
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ prev_token = None
+ found_event_ids = []
+ encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8"))
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + prev_token
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s"
+ "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
+ % (
+ self.room,
+ self.parent_id,
+ RelationTypes.ANNOTATION,
+ encoded_key,
+ from_token,
+ ),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ def test_aggregation(self):
+ """Test that annotations get correctly aggregated.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body,
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ )
+
+ def test_aggregation_redactions(self):
+ """Test that annotations get correctly aggregated after a redaction.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+ to_redact_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+        # Now let's redact one of the 'a' reactions
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
+ access_token=self.user_token,
+ content={},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body,
+ {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+ )
+
+ def test_aggregation_must_be_annotation(self):
+ """Test that aggregations must be annotations.
+ """
+
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
+ % (self.room, self.parent_id, RelationTypes.REPLACE),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(400, channel.code, channel.json_body)
+
+ def test_aggregation_get_event(self):
+ """Test that annotations and references get correctly bundled when
+ getting the parent event.
+ """
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply_2 = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ RelationTypes.REFERENCE: {
+ "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
+ },
+ },
+ )
+
+ def test_edit(self):
+ """Test that a simple edit works.
+ """
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ )
+
+ def test_multi_edit(self):
+ """Test that multiple edits, including attempts by people who
+ shouldn't be allowed, are correctly handled.
+ """
+
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message.WRONG_TYPE",
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ request, channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEquals(channel.json_body["content"], new_body)
+
+ self.assertEquals(
+ channel.json_body["unsigned"].get("m.relations"),
+ {RelationTypes.REPLACE: {"event_id": edit_event_id}},
+ )
+
+ def _send_relation(
+ self, relation_type, event_type, key=None, content={}, access_token=None
+ ):
+ """Helper function to send a relation pointing at `self.parent_id`
+
+ Args:
+ relation_type (str): One of `RelationTypes`
+ event_type (str): The type of the event to create
+ key (str|None): The aggregation key used for m.annotation relation
+ type.
+            content (dict|None): The content of the created event.
+ access_token (str|None): The access token used to send the relation,
+ defaults to `self.user_token`
+
+ Returns:
+ FakeChannel
+ """
+ if not access_token:
+ access_token = self.user_token
+
+ query = ""
+ if key:
+ query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
+ % (self.room, self.parent_id, relation_type, event_type, query),
+ json.dumps(content).encode("utf-8"),
+ access_token=access_token,
+ )
+ self.render(request)
+ return channel
+
+ def _create_user(self, localpart):
+ user_id = self.register_user(localpart, "abc123")
+ access_token = self.login(localpart, "abc123")
+
+ return user_id, access_token
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index f412985d2c..52739fbabc 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -59,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
for flow in channel.json_body["flows"]:
self.assertIsInstance(flow["stages"], list)
self.assertTrue(len(flow["stages"]) > 0)
- self.assertEquals(flow["stages"][-1], "m.login.terms")
+ self.assertTrue("m.login.terms" in flow["stages"])
expected_params = {
"m.login.terms": {
|