summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6409.feature1
-rw-r--r--synapse/api/constants.py4
-rw-r--r--synapse/config/server.py2
-rw-r--r--synapse/handlers/federation.py8
-rw-r--r--synapse/handlers/message.py123
-rw-r--r--synapse/storage/data_stores/main/events.py126
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql21
-rw-r--r--tests/rest/client/test_ephemeral_message.py101
8 files changed, 379 insertions, 7 deletions
diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature
new file mode 100644
index 0000000000..653ff5a5ad
--- /dev/null
+++ b/changelog.d/6409.feature
@@ -0,0 +1 @@
+Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228).
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e3f086f1c3..69cef369a5 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -147,3 +147,7 @@ class EventContentFields(object):
 
     # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
     LABELS = "org.matrix.labels"
+
+    # Timestamp to delete the event after
+    # cf https://github.com/matrix-org/matrix-doc/pull/2228
+    SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7a9d711669..837fbe1582 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -490,6 +490,8 @@ class ServerConfig(Config):
             "cleanup_extremities_with_dummy_events", True
         )
 
+        self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
+
     def has_tls_listener(self) -> bool:
         return any(l["tls"] for l in self.listeners)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d3267734f7..d9d0cd9eef 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -121,6 +121,7 @@ class FederationHandler(BaseHandler):
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
         self.http_client = hs.get_simple_http_client()
@@ -141,6 +142,8 @@ class FederationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
         """ Process a PDU received via a federation /send/ transaction, or
@@ -2715,6 +2718,11 @@ class FederationHandler(BaseHandler):
                 event_and_contexts, backfilled=backfilled
             )
 
+            if self._ephemeral_messages_enabled:
+                for (event, context) in event_and_contexts:
+                    # If there's an expiry timestamp on the event, schedule its expiry.
+                    self._message_handler.maybe_schedule_expiry(event)
+
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
                     yield self._notify_persisted_event(event, max_stream_id)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3b0156f516..4f53a5f5dc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Optional
 
 from six import iteritems, itervalues, string_types
 
@@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 from twisted.internet.defer import succeed
+from twisted.internet.interfaces import IDelayedCall
 
 from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+    UserTypes,
+)
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -62,6 +70,17 @@ class MessageHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
+        self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+        self._is_worker_app = bool(hs.config.worker_app)
+
+        # The scheduled call to self._expire_event. None if no call is currently
+        # scheduled.
+        self._scheduled_expiry = None  # type: Optional[IDelayedCall]
+
+        if not hs.config.worker_app:
+            run_as_background_process(
+                "_schedule_next_expiry", self._schedule_next_expiry
+            )
 
     @defer.inlineCallbacks
     def get_room_data(
@@ -225,6 +244,100 @@ class MessageHandler(object):
             for user_id, profile in iteritems(users_with_profile)
         }
 
+    def maybe_schedule_expiry(self, event):
+        """Schedule the expiry of an event if there's not already one scheduled,
+        or if the one running is for an event that will expire after the provided
+        timestamp.
+
+        This function needs to invalidate the event cache, which is only possible on
+        the master process, and therefore needs to be run on there.
+
+        Args:
+            event (EventBase): The event to schedule the expiry of.
+        """
+        assert not self._is_worker_app
+
+        expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+        if not isinstance(expiry_ts, int) or event.is_state():
+            return
+
+        # _schedule_expiry_for_event won't actually schedule anything if there's already
+        # a task scheduled for a timestamp that's sooner than the provided one.
+        self._schedule_expiry_for_event(event.event_id, expiry_ts)
+
+    @defer.inlineCallbacks
+    def _schedule_next_expiry(self):
+        """Retrieve the ID and the expiry timestamp of the next event to be expired,
+        and schedule an expiry task for it.
+
+        If there's no event left to expire, set _expiry_scheduled to None so that a
+        future call to save_expiry_ts can schedule a new expiry task.
+        """
+        # Try to get the expiry timestamp of the next event to expire.
+        res = yield self.store.get_next_event_to_expire()
+        if res:
+            event_id, expiry_ts = res
+            self._schedule_expiry_for_event(event_id, expiry_ts)
+
+    def _schedule_expiry_for_event(self, event_id, expiry_ts):
+        """Schedule an expiry task for the provided event if there's not already one
+        scheduled at a timestamp that's sooner than the provided one.
+
+        Args:
+            event_id (str): The ID of the event to expire.
+            expiry_ts (int): The timestamp at which to expire the event.
+        """
+        if self._scheduled_expiry:
+            # If the provided timestamp refers to a time before the scheduled time of the
+            # next expiry task, cancel that task and reschedule it for this timestamp.
+            next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000
+            if expiry_ts < next_scheduled_expiry_ts:
+                self._scheduled_expiry.cancel()
+            else:
+                return
+
+        # Figure out how many seconds we need to wait before expiring the event.
+        now_ms = self.clock.time_msec()
+        delay = (expiry_ts - now_ms) / 1000
+
+        # callLater doesn't support negative delays, so trim the delay to 0 if we're
+        # in that case.
+        if delay < 0:
+            delay = 0
+
+        logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay)
+
+        self._scheduled_expiry = self.clock.call_later(
+            delay,
+            run_as_background_process,
+            "_expire_event",
+            self._expire_event,
+            event_id,
+        )
+
+    @defer.inlineCallbacks
+    def _expire_event(self, event_id):
+        """Retrieve and expire an event that needs to be expired from the database.
+
+        If the event doesn't exist in the database, log it and delete the expiry date
+        from the database (so that we don't try to expire it again).
+        """
+        assert self._ephemeral_events_enabled
+
+        self._scheduled_expiry = None
+
+        logger.info("Expiring event %s", event_id)
+
+        try:
+            # Expire the event if we know about it. This function also deletes the expiry
+            # date from the database in the same database transaction.
+            yield self.store.expire_event(event_id)
+        except Exception as e:
+            logger.error("Could not expire event %s: %r", event_id, e)
+
+        # Schedule the expiry of the next event to expire.
+        yield self._schedule_next_expiry()
+
 
 # The duration (in ms) after which rooms should be removed
 # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try
@@ -295,6 +408,10 @@ class EventCreationHandler(object):
                 5 * 60 * 1000,
             )
 
+        self._message_handler = hs.get_message_handler()
+
+        self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def create_event(
         self,
@@ -877,6 +994,10 @@ class EventCreationHandler(object):
             event, context=context
         )
 
+        if self._ephemeral_events_enabled:
+            # If there's an expiry timestamp on the event, schedule its expiry.
+            self._message_handler.maybe_schedule_expiry(event)
+
         yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
 
         def _notify():
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 2737a1d3ae..79c91fe284 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -130,6 +130,8 @@ class EventsStore(
         if self.hs.config.redaction_retention_period is not None:
             hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
         def fetch(txn):
@@ -940,6 +942,12 @@ class EventsStore(
                     txn, event.event_id, labels, event.room_id, event.depth
                 )
 
+            if self._ephemeral_messages_enabled:
+                # If there's an expiry timestamp on the event, store it.
+                expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+                if isinstance(expiry_ts, int) and not event.is_state():
+                    self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
         # Insert into the room_memberships table.
         self._store_room_members_txn(
             txn,
@@ -1101,12 +1109,7 @@ class EventsStore(
         def _update_censor_txn(txn):
             for redaction_id, event_id, pruned_json in updates:
                 if pruned_json:
-                    self._simple_update_one_txn(
-                        txn,
-                        table="event_json",
-                        keyvalues={"event_id": event_id},
-                        updatevalues={"json": pruned_json},
-                    )
+                    self._censor_event_txn(txn, event_id, pruned_json)
 
                 self._simple_update_one_txn(
                     txn,
@@ -1117,6 +1120,22 @@ class EventsStore(
 
         yield self.runInteraction("_update_censor_txn", _update_censor_txn)
 
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self._simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
+
     @defer.inlineCallbacks
     def count_daily_messages(self):
         """
@@ -1957,6 +1976,101 @@ class EventsStore(
             ],
         )
 
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
+
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
+        """
+        return self._simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(prune_event_dict(event.get_dict()))
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self._simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there's no more event to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
 
 AllNewEventsResult = namedtuple(
     "AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+    event_id TEXT PRIMARY KEY,
+    expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
new file mode 100644
index 0000000000..5e9c07ebf3
--- /dev/null
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import room
+
+from tests import unittest
+
+
+class EphemeralMessageTestCase(unittest.HomeserverTestCase):
+
+    user_id = "@user:test"
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        config["enable_ephemeral_messages"] = True
+
+        self.hs = self.setup_test_homeserver(config=config)
+        return self.hs
+
+    def prepare(self, reactor, clock, homeserver):
+        self.room_id = self.helper.create_room_as(self.user_id)
+
+    def test_message_expiry_no_delay(self):
+        """Tests that sending a message sent with a m.self_destruct_after field set to the
+        past results in that event being deleted right away.
+        """
+        # Send a message in the room that has expired. From here, the reactor clock is
+        # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
+        # is at 0ms the code path is the same if the event's expiry timestamp is the
+        # current timestamp.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "hello",
+                EventContentFields.SELF_DESTRUCT_AFTER: 0,
+            },
+        )
+        event_id = res["event_id"]
+
+        # Check that we can't retrieve the content of the event.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertFalse(bool(event_content), event_content)
+
+    def test_message_expiry_delay(self):
+        """Tests that sending a message with a m.self_destruct_after field set to the
+        future results in that event not being deleted right away, but advancing the
+        clock to after that expiry timestamp causes the event to be deleted.
+        """
+        # Send a message in the room that'll expire in 1s.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "hello",
+                EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
+            },
+        )
+        event_id = res["event_id"]
+
+        # Check that we can retrieve the content of the event before it has expired.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertTrue(bool(event_content), event_content)
+
+        # Advance the clock to after the deletion.
+        self.reactor.advance(1)
+
+        # Check that we can't retrieve the content of the event anymore.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertFalse(bool(event_content), event_content)
+
+    def get_event(self, room_id, event_id, expected_code=200):
+        url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+        request, channel = self.make_request("GET", url)
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+        return channel.json_body