diff options
44 files changed, 1815 insertions, 122 deletions
diff --git a/INSTALL.md b/INSTALL.md index b88d826f6c..1934593148 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -35,7 +35,7 @@ virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate pip install --upgrade pip pip install --upgrade setuptools -pip install matrix-synapse[all] +pip install matrix-synapse ``` This will download Synapse from [PyPI](https://pypi.org/project/matrix-synapse) @@ -48,7 +48,7 @@ update flag: ``` source ~/synapse/env/bin/activate -pip install -U matrix-synapse[all] +pip install -U matrix-synapse ``` Before you can start Synapse, you will need to generate a configuration diff --git a/changelog.d/5039.bugfix b/changelog.d/5039.bugfix new file mode 100644 index 0000000000..212cff7ae8 --- /dev/null +++ b/changelog.d/5039.bugfix @@ -0,0 +1 @@ +Fix image orientation when generating thumbnails (needs pillow>=4.3.0). Contributed by Pau Rodriguez-Estivill. diff --git a/changelog.d/5174.bugfix b/changelog.d/5174.bugfix new file mode 100644 index 0000000000..0f26d46b2c --- /dev/null +++ b/changelog.d/5174.bugfix @@ -0,0 +1 @@ +Re-order stages in registration flows such that msisdn and email verification are done last. diff --git a/changelog.d/5177.bugfix b/changelog.d/5177.bugfix new file mode 100644 index 0000000000..c2f1644ae5 --- /dev/null +++ b/changelog.d/5177.bugfix @@ -0,0 +1 @@ +Fix 3pid guest invites. diff --git a/changelog.d/5191.misc b/changelog.d/5191.misc new file mode 100644 index 0000000000..e0615fec9c --- /dev/null +++ b/changelog.d/5191.misc @@ -0,0 +1 @@ +Make generating SQL bounds for pagination generic. diff --git a/changelog.d/5197.misc b/changelog.d/5197.misc new file mode 100644 index 0000000000..fca1d86b2e --- /dev/null +++ b/changelog.d/5197.misc @@ -0,0 +1 @@ +Stop telling people to install the optional dependencies by default. diff --git a/changelog.d/5198.bugfix b/changelog.d/5198.bugfix new file mode 100644 index 0000000000..c6b156f17d --- /dev/null +++ b/changelog.d/5198.bugfix @@ -0,0 +1 @@ +Prevent registration for user ids that are to long to fit into a state key. Contributed by Reid Anderson. \ No newline at end of file diff --git a/changelog.d/5209.feature b/changelog.d/5209.feature new file mode 100644 index 0000000000..747098c166 --- /dev/null +++ b/changelog.d/5209.feature @@ -0,0 +1 @@ +Add experimental support for relations (aka reactions and edits). diff --git a/changelog.d/5210.feature b/changelog.d/5210.feature new file mode 100644 index 0000000000..a7476bf9b9 --- /dev/null +++ b/changelog.d/5210.feature @@ -0,0 +1 @@ +Add a new room version which uses a new event ID format. diff --git a/changelog.d/5211.feature b/changelog.d/5211.feature new file mode 100644 index 0000000000..747098c166 --- /dev/null +++ b/changelog.d/5211.feature @@ -0,0 +1 @@ +Add experimental support for relations (aka reactions and edits). diff --git a/changelog.d/5219.bugfix b/changelog.d/5219.bugfix new file mode 100644 index 0000000000..c1e17adc5d --- /dev/null +++ b/changelog.d/5219.bugfix @@ -0,0 +1 @@ +Fix error handling for rooms whose versions are unknown. diff --git a/debian/test/.gitignore b/debian/test/.gitignore new file mode 100644 index 0000000000..95eda73fcc --- /dev/null +++ b/debian/test/.gitignore @@ -0,0 +1,2 @@ +.vagrant +*.log diff --git a/debian/test/provision.sh b/debian/test/provision.sh new file mode 100644 index 0000000000..a5c7f59712 --- /dev/null +++ b/debian/test/provision.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# +# provisioning script for vagrant boxes for testing the matrix-synapse debs. +# +# Will install the most recent matrix-synapse-py3 deb for this platform from +# the /debs directory. + +set -e + +apt-get update +apt-get install -y lsb-release + +deb=`ls /debs/matrix-synapse-py3_*+$(lsb_release -cs)*.deb | sort | tail -n1` + +debconf-set-selections <<EOF +matrix-synapse matrix-synapse/report-stats boolean false +matrix-synapse matrix-synapse/server-name string localhost:18448 +EOF + +dpkg -i "$deb" + +sed -i -e '/port: 8...$/{s/8448/18448/; s/8008/18008/}' -e '$aregistration_shared_secret: secret' /etc/matrix-synapse/homeserver.yaml +systemctl restart matrix-synapse diff --git a/debian/test/stretch/Vagrantfile b/debian/test/stretch/Vagrantfile new file mode 100644 index 0000000000..d8eff6fe11 --- /dev/null +++ b/debian/test/stretch/Vagrantfile @@ -0,0 +1,13 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +ver = `cd ../../..; dpkg-parsechangelog -S Version`.strip() + +Vagrant.configure("2") do |config| + config.vm.box = "debian/stretch64" + + config.vm.synced_folder ".", "/vagrant", disabled: true + config.vm.synced_folder "../../../../debs", "/debs", type: "nfs" + + config.vm.provision "shell", path: "../provision.sh" +end diff --git a/debian/test/xenial/Vagrantfile b/debian/test/xenial/Vagrantfile new file mode 100644 index 0000000000..189236da17 --- /dev/null +++ b/debian/test/xenial/Vagrantfile @@ -0,0 +1,10 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : + +Vagrant.configure("2") do |config| + config.vm.box = "ubuntu/xenial64" + + config.vm.synced_folder ".", "/vagrant", disabled: true + config.vm.synced_folder "../../../../debs", "/debs" + config.vm.provision "shell", path: "../provision.sh" +end diff --git a/docs/postgres.rst b/docs/postgres.rst index f7ebbed0c3..e81e10403f 100644 --- a/docs/postgres.rst +++ b/docs/postgres.rst @@ -3,6 +3,28 @@ Using Postgres Postgres version 9.4 or later is known to work. +Install postgres client libraries +================================= + +Synapse will require the python postgres client library in order to connect to +a postgres database. + +* If you are using the `matrix.org debian/ubuntu + packages <../INSTALL.md#matrixorg-packages>`_, + the necessary libraries will already be installed. + +* For other pre-built packages, please consult the documentation from the + relevant package. + +* If you installed synapse `in a virtualenv + <../INSTALL.md#installing-from-source>`_, you can install the library with:: + + ~/synapse/env/bin/pip install matrix-synapse[postgres] + + (substituting the path to your virtualenv for ``~/synapse/env``, if you used a + different path). You will require the postgres development files. These are in + the ``libpq-dev`` package on Debian-derived distributions. + Set up database =============== @@ -26,29 +48,6 @@ encoding use, e.g.:: This would create an appropriate database named ``synapse`` owned by the ``synapse_user`` user (which must already exist). -Set up client in Debian/Ubuntu -=========================== - -Postgres support depends on the postgres python connector ``psycopg2``. In the -virtual env:: - - sudo apt-get install libpq-dev - pip install psycopg2 - -Set up client in RHEL/CentOs 7 -============================== - -Make sure you have the appropriate version of postgres-devel installed. For a -postgres 9.4, use the postgres 9.4 packages from -[here](https://wiki.postgresql.org/wiki/YUM_Installation). - -As with Debian/Ubuntu, postgres support depends on the postgres python connector -``psycopg2``. In the virtual env:: - - sudo yum install postgresql-devel libpqxx-devel.x86_64 - export PATH=/usr/pgsql-9.4/bin/:$PATH - pip install psycopg2 - Tuning Postgres =============== diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8547a63535..6b347b1749 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -23,6 +23,9 @@ MAX_DEPTH = 2**63 - 1 # the maximum length for a room alias is 255 characters MAX_ALIAS_LENGTH = 255 +# the maximum length for a user id is 255 characters +MAX_USERID_LENGTH = 255 + class Membership(object): @@ -116,3 +119,11 @@ class UserTypes(object): """ SUPPORT = "support" ALL_USER_TYPES = (SUPPORT,) + + +class RelationTypes(object): + """The types of relations known to this server. + """ + ANNOTATION = "m.annotation" + REPLACE = "m.replace" + REFERENCE = "m.reference" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index ff89259dec..e91697049c 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -328,9 +328,23 @@ class RoomKeysVersionError(SynapseError): self.current_version = current_version +class UnsupportedRoomVersionError(SynapseError): + """The client's request to create a room used a room version that the server does + not support.""" + def __init__(self): + super(UnsupportedRoomVersionError, self).__init__( + code=400, + msg="Homeserver does not support this room version", + errcode=Codes.UNSUPPORTED_ROOM_VERSION, + ) + + class IncompatibleRoomVersionError(SynapseError): - """A server is trying to join a room whose version it does not support.""" + """A server is trying to join a room whose version it does not support. + Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join + failing. + """ def __init__(self, room_version): super(IncompatibleRoomVersionError, self).__init__( code=400, diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index e77abe1040..485b3d0237 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -19,13 +19,15 @@ class EventFormatVersions(object): """This is an internal enum for tracking the version of the event format, independently from the room version. """ - V1 = 1 # $id:server format - V2 = 2 # MSC1659-style $hash format: introduced for room v3 + V1 = 1 # $id:server event id format + V2 = 2 # MSC1659-style $hash event id format: introduced for room v3 + V3 = 3 # MSC1884-style $hash format: introduced for room v4 KNOWN_EVENT_FORMAT_VERSIONS = { EventFormatVersions.V1, EventFormatVersions.V2, + EventFormatVersions.V3, } @@ -75,6 +77,12 @@ class RoomVersions(object): EventFormatVersions.V2, StateResolutionVersions.V2, ) + EVENTID_NOSLASH_TEST = RoomVersion( + "eventid-noslash-test", + RoomDisposition.UNSTABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + ) # the version we will give rooms which are created on this server @@ -87,5 +95,6 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V2, RoomVersions.V3, RoomVersions.STATE_V2_TEST, + RoomVersions.EVENTID_NOSLASH_TEST, ) } # type: dict[str, RoomVersion] diff --git a/synapse/config/server.py b/synapse/config/server.py index 1b8968608e..f34aa42afa 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -101,6 +101,11 @@ class ServerConfig(Config): "block_non_admin_invites", False, ) + # Whether to enable experimental MSC1849 (aka relations) support + self.experimental_msc1849_support_enabled = config.get( + "experimental_msc1849_support_enabled", False, + ) + # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 12056d5be2..1edd19cc13 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -21,6 +21,7 @@ import six from unpaddedbase64 import encode_base64 +from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -335,13 +336,32 @@ class FrozenEventV2(EventBase): return self.__repr__() def __repr__(self): - return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % ( + return "<%s event_id='%s', type='%s', state_key='%s'>" % ( + self.__class__.__name__, self.event_id, self.get("type", None), self.get("state_key", None), ) +class FrozenEventV3(FrozenEventV2): + """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" + format_version = EventFormatVersions.V3 # All events of this type are V3 + + @property + def event_id(self): + # We have to import this here as otherwise we get an import loop which + # is hard to break. + from synapse.crypto.event_signing import compute_event_reference_hash + + if self._event_id: + return self._event_id + self._event_id = "$" + encode_base64( + compute_event_reference_hash(self)[1], urlsafe=True + ) + return self._event_id + + def room_version_to_event_format(room_version): """Converts a room version string to the event format @@ -350,12 +370,15 @@ def room_version_to_event_format(room_version): Returns: int + + Raises: + UnsupportedRoomVersionError if the room version is unknown """ v = KNOWN_ROOM_VERSIONS.get(room_version) if not v: - # We should have already checked version, so this should not happen - raise RuntimeError("Unrecognized room version %s" % (room_version,)) + # this can happen if support is withdrawn for a room version + raise UnsupportedRoomVersionError() return v.event_format @@ -376,6 +399,8 @@ def event_type_from_format_version(format_version): return FrozenEvent elif format_version == EventFormatVersions.V2: return FrozenEventV2 + elif format_version == EventFormatVersions.V3: + return FrozenEventV3 else: raise Exception( "No event format %r" % (format_version,) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index fba27177c7..1fe995f212 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -18,6 +18,7 @@ import attr from twisted.internet import defer from synapse.api.constants import MAX_DEPTH +from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, KNOWN_ROOM_VERSIONS, @@ -178,9 +179,8 @@ class EventBuilderFactory(object): """ v = KNOWN_ROOM_VERSIONS.get(room_version) if not v: - raise Exception( - "No event format defined for version %r" % (room_version,) - ) + # this can happen if support is withdrawn for a room version + raise UnsupportedRoomVersionError() return self.for_room_version(v, key_values) def for_room_version(self, room_version, key_values): diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a5454556cc..27a2a9ef98 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -21,7 +21,7 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -324,8 +324,12 @@ class EventClientSerializer(object): """ def __init__(self, hs): - pass + self.store = hs.get_datastore() + self.experimental_msc1849_support_enabled = ( + hs.config.experimental_msc1849_support_enabled + ) + @defer.inlineCallbacks def serialize_event(self, event, time_now, **kwargs): """Serializes a single event. @@ -337,8 +341,52 @@ class EventClientSerializer(object): Returns: Deferred[dict]: The serialized event """ - event = serialize_event(event, time_now, **kwargs) - return defer.succeed(event) + # To handle the case of presence events and the like + if not isinstance(event, EventBase): + defer.returnValue(event) + + event_id = event.event_id + serialized_event = serialize_event(event, time_now, **kwargs) + + # If MSC1849 is enabled then we need to look if thre are any relations + # we need to bundle in with the event + if self.experimental_msc1849_support_enabled: + annotations = yield self.store.get_aggregation_groups_for_event( + event_id, + ) + references = yield self.store.get_relations_for_event( + event_id, RelationTypes.REFERENCE, direction="f", + ) + + if annotations.chunk: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.ANNOTATION] = annotations.to_dict() + + if references.chunk: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = yield self.store.get_applicable_edit(event_id) + + if edit: + # If there is an edit replace the content, preserving existing + # relations. + + relations = event.content.get("m.relates_to") + serialized_event["content"] = edit.content.get("m.new_content", {}) + if relations: + serialized_event["content"]["m.relates_to"] = relations + else: + serialized_event["content"].pop("m.relates_to", None) + + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.REPLACE] = { + "event_id": edit.event_id, + } + + defer.returnValue(serialized_event) def serialize_events(self, events, time_now, **kwargs): """Serializes multiple events. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index df60828dba..4c28c1dc3c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( IncompatibleRoomVersionError, NotFoundError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.crypto.event_signing import compute_event_signature @@ -198,11 +199,22 @@ class FederationServer(FederationBase): try: room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue + try: + format_ver = room_version_to_event_format(room_version) + except UnsupportedRoomVersionError: + # this can happen if support for a given room version is withdrawn, + # so that we still get events for said room. + logger.info( + "Ignoring PDU for room %s with unknown version %s", + room_id, + room_version, + ) + continue + event = event_from_pdu_json(p, format_ver) pdus_by_room.setdefault(room_id, []).append(event) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a51d11a257..e83ee24f10 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -19,7 +19,7 @@ import logging from twisted.internet import defer from synapse import types -from synapse.api.constants import LoginType +from synapse.api.constants import MAX_USERID_LENGTH, LoginType from synapse.api.errors import ( AuthError, Codes, @@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler): self.check_user_id_not_appservice_exclusive(user_id) + if len(user_id) > MAX_USERID_LENGTH: + raise SynapseError( + 400, + "User ID may not be longer than %s characters" % ( + MAX_USERID_LENGTH, + ), + Codes.INVALID_USERNAME + ) + users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ffc588d454..93ac986c86 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -944,7 +944,7 @@ class RoomMemberHandler(object): } if self.config.invite_3pid_guest: - guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest( + guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest( requester=requester, medium=medium, address=address, diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2708f5e820..fdcfb90a7e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -53,7 +53,7 @@ REQUIREMENTS = [ "pyasn1-modules>=0.0.7", "daemonize>=2.3.1", "bcrypt>=3.1.0", - "pillow>=3.1.2", + "pillow>=4.3.0", "sortedcontainers>=1.4.4", "psutil>=2.0.0", "pymacaroons>=0.13.0", diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b457c5563f..a3952506c1 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.storage.event_federation import EventFederationWorkerStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore from synapse.storage.events_worker import EventsWorkerStore +from synapse.storage.relations import RelationsWorkerStore from synapse.storage.roommember import RoomMemberWorkerStore from synapse.storage.signatures import SignatureWorkerStore from synapse.storage.state import StateGroupWorkerStore @@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore, EventsWorkerStore, SignatureWorkerStore, UserErasureWorkerStore, + RelationsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): @@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore, for row in rows: self.invalidate_caches_for_event( -token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, + row.redacts, row.relates_to, backfilled=True, ) return super(SlavedEventStore, self).process_replication_rows( @@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore, if row.type == EventsStreamEventRow.TypeId: self.invalidate_caches_for_event( token, data.event_id, data.room_id, data.type, data.state_key, - data.redacts, + data.redacts, data.relates_to, backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: @@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore, raise Exception("Unknown events stream row type %s" % (row.type, )) def invalidate_caches_for_event(self, stream_ordering, event_id, room_id, - etype, state_key, redacts, backfilled): + etype, state_key, redacts, relates_to, + backfilled): self._invalidate_get_event_cache(event_id) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore, state_key, stream_ordering ) self.get_invited_rooms_for_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 8971a6a22e..b6ce7a7bee 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", ( "type", # str "state_key", # str, optional "redacts", # str, optional + "relates_to", # str, optional )) PresenceStreamRow = namedtuple("PresenceStreamRow", ( "user_id", # str diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index e0f6e29248..f1290d022a 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -80,11 +80,12 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional + relates_to = attr.ib() # str, optional @attr.s(slots=True, frozen=True) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3a24d31d1b..e6110ad9b1 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import ( read_marker, receipts, register, + relations, report_event, room_keys, room_upgrade_rest_servlet, @@ -115,6 +116,7 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) + relations.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index fa0cedb8d4..042f636135 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -348,18 +348,22 @@ class RegisterRestServlet(RestServlet): if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: - flows.extend([[LoginType.RECAPTCHA]]) + # Also add a dummy flow here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. + flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) # only support the email-only flow if we don't require MSISDN 3PIDs if not require_msisdn: - flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]]) if show_msisdn: # only support the MSISDN-only flow if we don't require email 3PIDs if not require_email: - flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], + [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) else: # only support 3PIDless registration if no 3PIDs are required @@ -382,7 +386,15 @@ class RegisterRestServlet(RestServlet): if self.hs.config.user_consent_at_registration: new_flows = [] for flow in flows: - flow.append(LoginType.TERMS) + inserted = False + # m.login.terms should go near the end but before msisdn or email auth + for i, stage in enumerate(flow): + if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN: + flow.insert(i, LoginType.TERMS) + inserted = True + break + if not inserted: + flow.append(LoginType.TERMS) flows.extend(new_flows) auth_result, params, session_id = yield self.auth_handler.check_auth( diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py new file mode 100644 index 0000000000..41e0a44936 --- /dev/null +++ b/synapse/rest/client/v2_alpha/relations.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This class implements the proposed relation APIs from MSC 1849. + +Since the MSC has not been approved all APIs here are unstable and may change at +any time to reflect changes in the MSC. +""" + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken + +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. + + Example API shape to send a 👍 reaction to a room: + + POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D + {} + + { + "event_id": "$foobar" + } + """ + + PATTERN = ( + "/rooms/(?P<room_id>[^/]*)/send_relation" + "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" + ) + + def __init__(self, hs): + super(RelationSendServlet, self).__init__() + self.auth = hs.get_auth() + self.event_creation_handler = hs.get_event_creation_handler() + self.txns = HttpTransactionCache(hs) + + def register(self, http_server): + http_server.register_paths( + "POST", + client_v2_patterns(self.PATTERN + "$", releases=()), + self.on_PUT_or_POST, + ) + http_server.register_paths( + "PUT", + client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), + self.on_PUT, + ) + + def on_PUT(self, request, *args, **kwargs): + return self.txns.fetch_or_execute_request( + request, self.on_PUT_or_POST, request, *args, **kwargs + ) + + @defer.inlineCallbacks + def on_PUT_or_POST( + self, request, room_id, parent_id, relation_type, event_type, txn_id=None + ): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + if event_type == EventTypes.Member: + # Add relations to a membership is meaningless, so we just deny it + # at the CS API rather than trying to handle it correctly. + raise SynapseError(400, "Cannot send member events with relations") + + content = parse_json_object_from_request(request) + + aggregation_key = parse_string(request, "key", encoding="utf-8") + + content["m.relates_to"] = { + "event_id": parent_id, + "key": aggregation_key, + "rel_type": relation_type, + } + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + event = yield self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + +class RelationPaginationServlet(RestServlet): + """API to paginate relations on an event by topological ordering, optionally + filtered by relation type and event type. + """ + + PATTERNS = client_v2_patterns( + "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)" + "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + +class RelationAggregationPaginationServlet(RestServlet): + """API to paginate aggregation groups of relations, e.g. paginate the + types and counts of the reactions on the events. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id} + + { + chunk: [ + { + "type": "m.reaction", + "key": "👍", + "count": 3 + } + ] + } + """ + + PATTERNS = client_v2_patterns( + "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" + "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type not in (RelationTypes.ANNOTATION, None): + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + res = yield self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + defer.returnValue((200, res.to_dict())) + + +class RelationAggregationGroupPaginationServlet(RestServlet): + """API to paginate within an aggregation group of relations, e.g. paginate + all the 👍 reactions on an event. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 + + { + chunk: [ + { + "type": "m.reaction", + "content": { + "m.relates_to": { + "rel_type": "m.annotation", + "key": "👍" + } + } + }, + ... + ] + } + """ + + PATTERNS = client_v2_patterns( + "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" + "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationGroupPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type != RelationTypes.ANNOTATION: + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=key, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) + RelationPaginationServlet(hs).register(http_server) + RelationAggregationPaginationServlet(hs).register(http_server) + RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index bdffa97805..8569677355 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -444,6 +444,9 @@ class MediaRepository(object): ) return + if thumbnailer.transpose_method is not None: + m_width, m_height = thumbnailer.transpose() + if t_method == "crop": t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": @@ -578,6 +581,12 @@ class MediaRepository(object): ) return + if thumbnailer.transpose_method is not None: + m_width, m_height = yield logcontext.defer_to_thread( + self.hs.get_reactor(), + thumbnailer.transpose + ) + # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails = {} diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a4b26c2587..3efd0d80fc 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -20,6 +20,17 @@ import PIL.Image as Image logger = logging.getLogger(__name__) +EXIF_ORIENTATION_TAG = 0x0112 +EXIF_TRANSPOSE_MAPPINGS = { + 2: Image.FLIP_LEFT_RIGHT, + 3: Image.ROTATE_180, + 4: Image.FLIP_TOP_BOTTOM, + 5: Image.TRANSPOSE, + 6: Image.ROTATE_270, + 7: Image.TRANSVERSE, + 8: Image.ROTATE_90 +} + class Thumbnailer(object): @@ -31,6 +42,30 @@ class Thumbnailer(object): def __init__(self, input_path): self.image = Image.open(input_path) self.width, self.height = self.image.size + self.transpose_method = None + try: + # We don't use ImageOps.exif_transpose since it crashes with big EXIF + image_exif = self.image._getexif() + if image_exif is not None: + image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) + self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) + except Exception as e: + # A lot of parsing errors can happen when parsing EXIF + logger.info("Error parsing image EXIF information: %s", e) + + def transpose(self): + """Transpose the image using its EXIF Orientation tag + + Returns: + Tuple[int, int]: (width, height) containing the new image size in pixels. + """ + if self.transpose_method is not None: + self.image = self.image.transpose(self.transpose_method) + self.width, self.height = self.image.size + self.transpose_method = None + # We don't need EXIF any more + self.image.info["exif"] = None + return self.image.size def aspect(self, max_width, max_height): """Calculate the largest size that preserves aspect ratio which diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c432041b4e..7522d3fd57 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -49,6 +49,7 @@ from .pusher import PusherStore from .receipts import ReceiptsStore from .registration import RegistrationStore from .rejections import RejectionsStore +from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore @@ -99,6 +100,7 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + RelationsStore, ): def __init__(self, db_conn, hs): self.hs = hs diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7a7f841c6c..881d6d0126 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1325,6 +1325,9 @@ class EventsStore( txn, event.room_id, event.redacts ) + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -1351,6 +1354,8 @@ class EventsStore( # Insert into the event_search table. self._store_guest_access_txn(txn, event) + self._handle_event_relations(txn, event) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1655,10 +1660,11 @@ class EventsStore( def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1673,11 +1679,12 @@ class EventsStore( sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " ORDER BY event_stream_ordering DESC" @@ -1698,10 +1705,11 @@ class EventsStore( def get_all_new_backfill_event_rows(txn): sql = ( "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" " ORDER BY stream_ordering ASC" " LIMIT ?" @@ -1716,11 +1724,12 @@ class EventsStore( sql = ( "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" + " state_key, redacts, relates_to_id" " FROM events AS e" " INNER JOIN ex_outlier_stream USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" " ORDER BY event_stream_ordering DESC" diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py new file mode 100644 index 0000000000..493abe405e --- /dev/null +++ b/synapse/storage/relations.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import attr + +from twisted.internet import defer + +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.stream import generate_pagination_where_clause +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +logger = logging.getLogger(__name__) + + +@attr.s +class PaginationChunk(object): + """Returned by relation pagination APIs. + + Attributes: + chunk (list): The rows returned by pagination + next_batch (Any|None): Token to fetch next set of results with, if + None then there are no more results. + prev_batch (Any|None): Token to fetch previous set of results with, if + None then there are no previous results. + """ + + chunk = attr.ib() + next_batch = attr.ib(default=None) + prev_batch = attr.ib(default=None) + + def to_dict(self): + d = {"chunk": self.chunk} + + if self.next_batch: + d["next_batch"] = self.next_batch.to_string() + + if self.prev_batch: + d["prev_batch"] = self.prev_batch.to_string() + + return d + + +@attr.s(frozen=True, slots=True) +class RelationPaginationToken(object): + """Pagination token for relation pagination API. + + As the results are order by topological ordering, we can use the + `topological_ordering` and `stream_ordering` fields of the events at the + boundaries of the chunk as pagination tokens. + + Attributes: + topological (int): The topological ordering of the boundary event + stream (int): The stream ordering of the boundary event. + """ + + topological = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + t, s = string.split("-") + return RelationPaginationToken(int(t), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.topological, self.stream) + + def as_tuple(self): + return attr.astuple(self) + + +@attr.s(frozen=True, slots=True) +class AggregationPaginationToken(object): + """Pagination token for relation aggregation pagination API. + + As the results are order by count and then MAX(stream_ordering) of the + aggregation groups, we can just use them as our pagination token. + + Attributes: + count (int): The count of relations in the boundar group. + stream (int): The MAX stream ordering in the boundary group. + """ + + count = attr.ib() + stream = attr.ib() + + @staticmethod + def from_string(string): + try: + c, s = string.split("-") + return AggregationPaginationToken(int(c), int(s)) + except ValueError: + raise SynapseError(400, "Invalid token") + + def to_string(self): + return "%d-%d" % (self.count, self.stream) + + def as_tuple(self): + return attr.astuple(self) + + +class RelationsWorkerStore(SQLBaseStore): + @cached(tree=True) + def get_relations_for_event( + self, + event_id, + relation_type=None, + event_type=None, + aggregation_key=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of relations for an event, ordered by topological ordering. + + Args: + event_id (str): Fetch events that relate to this event ID. + relation_type (str|None): Only fetch events with this relation + type, if given. + event_type (str|None): Only fetch events with this event type, if + given. + aggregation_key (str|None): Only fetch events with this aggregation + key, if given. + limit (int): Only fetch the most recent `limit` events. + direction (str): Whether to fetch the most recent first (`"b"`) or + the oldest first (`"f"`). + from_token (RelationPaginationToken|None): Fetch rows from the given + token, or from the start if None. + to_token (RelationPaginationToken|None): Fetch rows up to the given + token, or up to the end if None. + + Returns: + Deferred[PaginationChunk]: List of event IDs that match relations + requested. The rows are of the form `{"event_id": "..."}`. + """ + + where_clause = ["relates_to_id = ?"] + where_args = [event_id] + + if relation_type is not None: + where_clause.append("relation_type = ?") + where_args.append(relation_type) + + if event_type is not None: + where_clause.append("type = ?") + where_args.append(event_type) + + if aggregation_key: + where_clause.append("aggregation_key = ?") + where_args.append(aggregation_key) + + pagination_clause = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if pagination_clause: + where_clause.append(pagination_clause) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + sql = """ + SELECT event_id, topological_ordering, stream_ordering + FROM event_relations + INNER JOIN events USING (event_id) + WHERE %s + ORDER BY topological_ordering %s, stream_ordering %s + LIMIT ? + """ % ( + " AND ".join(where_clause), + order, + order, + ) + + def _get_recent_references_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + last_topo_id = None + last_stream_id = None + events = [] + for row in txn: + events.append({"event_id": row[0]}) + last_topo_id = row[1] + last_stream_id = row[2] + + next_batch = None + if len(events) > limit and last_topo_id and last_stream_id: + next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_recent_references_for_event", _get_recent_references_for_event_txn + ) + + @cached(tree=True) + def get_aggregation_groups_for_event( + self, + event_id, + event_type=None, + limit=5, + direction="b", + from_token=None, + to_token=None, + ): + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id (str): Fetch events that relate to this event ID. + event_type (str|None): Only fetch events with this event type, if + given. + limit (int): Only fetch the `limit` groups. + direction (str): Whether to fetch the highest count first (`"b"`) or + the lowest count first (`"f"`). + from_token (AggregationPaginationToken|None): Fetch rows from the + given token, or from the start if None. + to_token (AggregationPaginationToken|None): Fetch rows up to the + given token, or up to the end if None. + + + Returns: + Deferred[PaginationChunk]: List of groups of annotations that + match. Each row is a dict with `type`, `key` and `count` fields. + """ + + where_clause = ["relates_to_id = ?", "relation_type = ?"] + where_args = [event_id, RelationTypes.ANNOTATION] + + if event_type: + where_clause.append("type = ?") + where_args.append(event_type) + + having_clause = generate_pagination_where_clause( + direction=direction, + column_names=("COUNT(*)", "MAX(stream_ordering)"), + from_token=attr.astuple(from_token) if from_token else None, + to_token=attr.astuple(to_token) if to_token else None, + engine=self.database_engine, + ) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + if having_clause: + having_clause = "HAVING " + having_clause + else: + having_clause = "" + + sql = """ + SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE {where_clause} + GROUP BY relation_type, type, aggregation_key + {having_clause} + ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order} + LIMIT ? + """.format( + where_clause=" AND ".join(where_clause), + order=order, + having_clause=having_clause, + ) + + def _get_aggregation_groups_for_event_txn(txn): + txn.execute(sql, where_args + [limit + 1]) + + next_batch = None + events = [] + for row in txn: + events.append({"type": row[0], "key": row[1], "count": row[2]}) + next_batch = AggregationPaginationToken(row[2], row[3]) + + if len(events) <= limit: + next_batch = None + + return PaginationChunk( + chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + ) + + return self.runInteraction( + "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn + ) + + @cachedInlineCallbacks() + def get_applicable_edit(self, event_id): + """Get the most recent edit (if any) that has happened for the given + event. + + Correctly handles checking whether edits were allowed to happen. + + Args: + event_id (str): The original event ID + + Returns: + Deferred[EventBase|None]: Returns the most recent edit, if any. + """ + + # We only allow edits for `m.room.message` events that have the same sender + # and event type. We can't assert these things during regular event auth so + # we have to do the checks post hoc. + + # Fetches latest edit that has the same type and sender as the + # original, and is an `m.room.message`. + sql = """ + SELECT edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + WHERE + relates_to_id = ? + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts DESC, edit.event_id DESC + LIMIT 1 + """ + + def _get_applicable_edit_txn(txn): + txn.execute( + sql, (event_id, RelationTypes.REPLACE,) + ) + row = txn.fetchone() + if row: + return row[0] + + edit_id = yield self.runInteraction( + "get_applicable_edit", _get_applicable_edit_txn + ) + + if not edit_id: + return + + edit_event = yield self.get_event(edit_id, allow_none=True) + defer.returnValue(edit_event) + + +class RelationsStore(RelationsWorkerStore): + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events + + Args: + txn + event (EventBase) + """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return + + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return + + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return + + aggregation_key = relation.get("key") + + self._simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, + ) + + txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + ) + + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, + table="event_relations", + keyvalues={ + "event_id": redacted_event_id, + } + ) diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql new file mode 100644 index 0000000000..134862b870 --- /dev/null +++ b/synapse/storage/schema/delta/54/relations.sql @@ -0,0 +1,27 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks related events, like reactions, replies, edits, etc. Note that things +-- in this table are not necessarily "valid", e.g. it may contain edits from +-- people who don't have power to edit other peoples events. +CREATE TABLE IF NOT EXISTS event_relations ( + event_id TEXT NOT NULL, + relates_to_id TEXT NOT NULL, + relation_type TEXT NOT NULL, + aggregation_key TEXT +); + +CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id); +CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index d105b6b17d..529ad4ea79 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -64,59 +64,135 @@ _EventDictReturn = namedtuple( ) -def lower_bound(token, engine, inclusive=False): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well - # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) <%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", +def generate_pagination_where_clause( + direction, column_names, from_token, to_token, engine, +): + """Creates an SQL expression to bound the columns by the pagination + tokens. + + For example creates an SQL expression like: + + (6, 7) >= (topological_ordering, stream_ordering) + AND (5, 3) < (topological_ordering, stream_ordering) + + would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + + Args: + direction (str): Whether we're paginating backwards("b") or + forwards ("f"). + column_names (tuple[str, str]): The column names to bound. Must *not* + be user defined as these get inserted directly into the SQL + statement without escapes. + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". + engine: The database engine to generate the clauses for + + Returns: + str: The sql expression + """ + assert direction in ("b", "f") + + where_clause = [] + if from_token: + where_clause.append( + _make_generic_sql_bound( + bound=">=" if direction == "b" else "<", + column_names=column_names, + values=from_token, + engine=engine, ) - return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", - ) - - -def upper_bound(token, engine, inclusive=True): - inclusive = "=" if inclusive else "" - if token.topological is None: - return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") - else: - if isinstance(engine, PostgresEngine): - # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well - # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we - # use the later form when running against postgres. - return "((%d,%d) >%s (%s,%s))" % ( - token.topological, - token.stream, - inclusive, - "topological_ordering", - "stream_ordering", + ) + + if to_token: + where_clause.append( + _make_generic_sql_bound( + bound="<" if direction == "b" else ">=", + column_names=column_names, + values=to_token, + engine=engine, ) - return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( - token.topological, - "topological_ordering", - token.topological, - "topological_ordering", - token.stream, - inclusive, - "stream_ordering", ) + return " AND ".join(where_clause) + + +def _make_generic_sql_bound(bound, column_names, values, engine): + """Create an SQL expression that bounds the given column names by the + values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. + + Only works with two columns. + + Older versions of SQLite don't support that syntax so we have to expand it + out manually. + + Args: + bound (str): The comparison operator to use. One of ">", "<", ">=", + "<=", where the values are on the left and columns on the right. + names (tuple[str, str]): The column names. Must *not* be user defined + as these get inserted directly into the SQL statement without + escapes. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. + engine: The database engine to generate the SQL for + + Returns: + str + """ + + assert(bound in (">", "<", ">=", "<=")) + + name1, name2 = column_names + val1, val2 = values + + if val1 is None: + val2 = int(val2) + return "(%d %s %s)" % (val2, bound, name2) + + val1 = int(val1) + val2 = int(val2) + + if isinstance(engine, PostgresEngine): + # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well + # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we + # use the later form when running against postgres. + return "((%d,%d) %s (%s,%s))" % ( + val1, val2, + bound, + name1, name2, + ) + + # We want to generate queries of e.g. the form: + # + # (val1 < name1 OR (val1 = name1 AND val2 <= name2)) + # + # which is equivalent to (val1, val2) < (name1, name2) + + return """( + {val1:d} {strict_bound} {name1} + OR ({val1:d} = {name1} AND {val2:d} {bound} {name2}) + )""".format( + name1=name1, + val1=val1, + name2=name2, + val2=val2, + strict_bound=bound[0], # The first bound must always be strict equality here + bound=bound, + ) + def filter_to_clause(event_filter): # NB: This may create SQL clauses that don't optimise well (and we don't @@ -762,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args = [False, room_id] if direction == 'b': order = "DESC" - bounds = upper_bound(from_token, self.database_engine) - if to_token: - bounds = "%s AND %s" % ( - bounds, - lower_bound(to_token, self.database_engine), - ) else: order = "ASC" - bounds = lower_bound(from_token, self.database_engine) - if to_token: - bounds = "%s AND %s" % ( - bounds, - upper_bound(to_token, self.database_engine), - ) + + bounds = generate_pagination_where_clause( + direction=direction, + column_names=("topological_ordering", "stream_ordering"), + from_token=from_token, + to_token=to_token, + engine=self.database_engine, + ) filter_clause, filter_args = filter_to_clause(event_filter) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1c253d0579..5ffba2ca7a 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -228,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_not_support_user(self): res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) + + def test_invalid_user_id_length(self): + invalid_user_id = "x" * 256 + self.get_failure( + self.handler.register(localpart=invalid_user_id), + SynapseError + ) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index ad7d476401..b9ef46e8fb 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -92,7 +92,14 @@ class FallbackAuthTests(unittest.HomeserverTestCase): self.assertEqual(len(self.recaptcha_attempts), 1) self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a") - # Now we have fufilled the recaptcha fallback step, we can then send a + # also complete the dummy auth + request, channel = self.make_request( + "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} + ) + self.render(request) + + # Now we should have fufilled a complete auth flow, including + # the recaptcha fallback step, we can then send a # request to the register API with the session in the authdict. request, channel = self.make_request( "POST", "register", {"auth": {"session": session}} diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py new file mode 100644 index 0000000000..3d040cf118 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -0,0 +1,539 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json + +import six + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import register, relations + +from tests import unittest + + +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + relations.register_servlets, + room.register_servlets, + login.register_servlets, + register.register_servlets, + admin.register_servlets_for_client_rest_resource, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + # We need to enable msc1849 support for aggregations + config = self.default_config() + config["experimental_msc1849_support_enabled"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.user_id, self.user_token = self._create_user("alice") + self.user2_id, self.user2_token = self._create_user("bob") + + self.room = self.helper.create_room_as(self.user_id, tok=self.user_token) + self.helper.join(self.room, user=self.user2_id, tok=self.user2_token) + res = self.helper.send(self.room, body="Hi!", tok=self.user_token) + self.parent_id = res["event_id"] + + def test_send_relation(self): + """Tests that sending a relation using the new /send_relation works + creates the right shape of event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key=u"👍") + self.assertEquals(200, channel.code, channel.json_body) + + event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, event_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assert_dict( + { + "type": "m.reaction", + "sender": self.user_id, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "key": u"👍", + "rel_type": RelationTypes.ANNOTATION, + } + }, + }, + channel.json_body, + ) + + def test_deny_membership(self): + """Test that we deny relations on membership events + """ + channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) + self.assertEquals(400, channel.code, channel.json_body) + + def test_basic_paginate_relations(self): + """Tests that calling pagination API corectly the latest relations. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, + channel.json_body["chunk"][0], + ) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), six.string_types, channel.json_body + ) + + def test_repeated_paginate_relations(self): + """Test that if we paginate using a limit and tokens then we get the + expected events. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly. + """ + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {u"👍": 10, u"a": 7, u"b": 5, u"c": 3, u"d": 2, u"e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + + idx += 1 + idx %= len(access_tokens) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group. + """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key=u"👍" + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = six.moves.urllib.parse.quote_plus(u"👍".encode("utf-8")) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redaction. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact one of the 'a' reactions + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations. + """ + + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + + def test_edit(self): + """Test that a simple edit works. + """ + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. + """ + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message.WRONG_TYPE", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + request, channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + {RelationTypes.REPLACE: {"event_id": edit_event_id}}, + ) + + def _send_relation( + self, relation_type, event_type, key=None, content={}, access_token=None + ): + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type (str): One of `RelationTypes` + event_type (str): The type of the event to create + key (str|None): The aggregation key used for m.annotation relation + type. + content(dict|None): The content of the created event. + access_token (str|None): The access token used to send the relation, + defaults to `self.user_token` + + Returns: + FakeChannel + """ + if not access_token: + access_token = self.user_token + + query = "" + if key: + query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8")) + + request, channel = self.make_request( + "POST", + "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" + % (self.room, self.parent_id, relation_type, event_type, query), + json.dumps(content).encode("utf-8"), + access_token=access_token, + ) + self.render(request) + return channel + + def _create_user(self, localpart): + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index f412985d2c..52739fbabc 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -59,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase): for flow in channel.json_body["flows"]: self.assertIsInstance(flow["stages"], list) self.assertTrue(len(flow["stages"]) > 0) - self.assertEquals(flow["stages"][-1], "m.login.terms") + self.assertTrue("m.login.terms" in flow["stages"]) expected_params = { "m.login.terms": { |