summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/deviceinbox.py139
1 files changed, 132 insertions, 7 deletions
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 57202a5bda..988577a334 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -27,28 +27,89 @@ logger = logging.getLogger(__name__)
 class DeviceInboxStore(SQLBaseStore):
 
     @defer.inlineCallbacks
-    def add_messages_to_device_inbox(self, messages_by_user_then_device):
-        """
+    def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
+                                     remote_messages_by_destination):
+        """Used to send messages from this server.
+
         Args:
-            messages_by_user_and_device(dict):
+            sender_user_id(str): The ID of the user sending these messages.
+            local_messages_by_user_and_device(dict):
                 Dictionary of user_id to device_id to message.
+            remote_messages_by_destination(dict):
+                Dictionary of destination server_name to the EDU JSON to send.
         Returns:
             A deferred stream_id that resolves when the messages have been
             inserted.
         """
 
+        def add_messages_to_device_federation_outbox(txn, now_ms, stream_id):
+            sql = (
+                "INSERT INTO device_federation_outbox"
+                " (destination, stream_id, queued_ts, messages_json)"
+                " VALUES (?,?,?,?)"
+            )
+            rows = []
+            for destination, edu in remote_messages_by_destination.items():
+                edu_json = ujson.dumps(edu)
+                rows.append((destination, stream_id, now_ms, edu_json))
+
+            txn.executemany(sql, rows)
+
+        def add_messages_txn(txn, now_ms, stream_id):
+            self._add_messages_to_local_device_inbox_txn(
+                txn, stream_id, local_messages_by_user_then_device
+            )
+            add_messages_to_device_federation_outbox(now_ms, stream_id)
+
         with self._device_inbox_id_gen.get_next() as stream_id:
+            now_ms = self.clock.time_now_ms()
             yield self.runInteraction(
                 "add_messages_to_device_inbox",
-                self._add_messages_to_device_inbox_txn,
+                add_messages_txn,
+                now_ms,
                 stream_id,
-                messages_by_user_then_device,
             )
 
         defer.returnValue(self._device_inbox_id_gen.get_current_token())
 
-    def _add_messages_to_device_inbox_txn(self, txn, stream_id,
-                                          messages_by_user_then_device):
+    @defer.inlineCallbacks
+    def add_messages_from_remote_to_device_inbox(
+        self, origin, message_id, local_messages_by_user_then_device
+    ):
+        def add_messages_txn(txn, now_ms, stream_id):
+            already_inserted = self._simple_select_one_txn(
+                txn, table="device_federation_inbox",
+                keyvalues={"origin": origin, "message_id": message_id},
+                retcols=("message_id",),
+                allow_none=True,
+            )
+            if already_inserted is not None:
+                return
+
+            self._simple_insert_txn(
+                txn, table="device_federation_inbox",
+                values={
+                    "origin": origin,
+                    "message_id": message_id,
+                    "received_ts": now_ms,
+                },
+            )
+
+            self._add_messages_to_local_device_inbox_txn(
+                txn, stream_id, local_messages_by_user_then_device
+            )
+
+        with self._device_inbox_id_gen.get_next() as stream_id:
+            now_ms = self.clock.time_now_ms()
+            yield self.runInteraction(
+                "add_messages_from_remote_to_device_inbox",
+                add_messages_txn,
+                now_ms,
+                stream_id,
+            )
+
+    def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
+                                                messages_by_user_then_device):
         local_users_and_devices = set()
         for user_id, messages_by_device in messages_by_user_then_device.items():
             devices = messages_by_device.keys()
@@ -177,3 +238,67 @@ class DeviceInboxStore(SQLBaseStore):
 
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
+
+    @defer.inlineCallbacks
+    def get_new_device_messages_for_remote_destination(
+        self, destination, last_stream_id, current_stream_id, limit=100
+    ):
+        """
+        Args:
+            destination(str): The name of the remote server.
+            last_stream_id(int): The last position of the device message stream
+                that the server sent up to.
+            current_stream_id(int): The current position of the device
+                message stream.
+        Returns:
+            Deferred ([dict], int): List of messages for the device and where
+                in the stream the messages got to.
+        """
+        def get_new_messages_for_remote_destination_txn(txn):
+            sql = (
+                "SELECT stream_id, messages_json FROM device_federation_outbox"
+                " WHERE destination = ?"
+                " AND ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (
+                destination, last_stream_id, current_stream_id, limit
+            ))
+            messages = []
+            for row in txn.fetchall():
+                stream_pos = row[0]
+                messages.append(ujson.loads(row[1]))
+            if len(messages) < limit:
+                stream_pos = current_stream_id
+            return (messages, stream_pos)
+
+        return self.runInteraction(
+            "get_new_device_messages_for_remote_destination",
+            get_new_messages_for_remote_destination_txn,
+        )
+
+    @defer.inlineCallbacks
+    def delete_device_messages_for_remote_destination(self, destination,
+                                                      up_to_stream_id):
+        """Used to delete messages when the remote destination acknowledges
+        their receipt.
+
+        Args:
+            destination(str): The destination server_name
+            up_to_stream_id(int): Where to delete messages up to.
+        Returns:
+            A deferred that resolves when the messages have been deleted.
+        """
+        def delete_messages_for_remote_destination_txn(txn):
+            sql = (
+                "DELETE FROM device_federation_outbox"
+                " WHERE destination = ? AND"
+                " AND stream_id <= ?"
+            )
+            txn.execute(sql, (destination, up_to_stream_id))
+
+        return self.runInteraction(
+            "delete_device_messages_for_remote_destination",
+            delete_messages_for_remote_destination_txn
+        )