summary refs log tree commit diff
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2016-09-06 18:16:20 +0100
committerMark Haines <mark.haines@matrix.org>2016-09-06 18:16:20 +0100
commitd4a35ada28302e096efd42e1a2a28542ed7ebd6f (patch)
tree679f1bf11e23af751c074728bbf1ebe8192c8c3b
parentAdd storage methods for federated device messages (diff)
downloadsynapse-d4a35ada28302e096efd42e1a2a28542ed7ebd6f.tar.xz
Send device messages over federation
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/transaction_queue.py43
-rw-r--r--synapse/handlers/devicemessage.py121
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py33
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/storage/deviceinbox.py19
-rw-r--r--synapse/storage/schema/delta/34/device_outbox.sql4
7 files changed, 179 insertions, 48 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5621655098..3fa7b2315c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -188,7 +188,7 @@ class FederationServer(FederationBase):
             except SynapseError as e:
                 logger.info("Failed to handle edu %r: %r", edu_type, e)
             except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type, e)
+                logger.exception("Failed to handle edu %r", edu_type)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..5e86141f86 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
 from twisted.internet import defer
 
 from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
@@ -187,6 +187,24 @@ class TransactionQueue(object):
                 destination, pending_pdus, pending_edus, pending_failures
             )
 
+    @defer.inlineCallbacks
+    def _get_new_device_messages(self, destination):
+        last_device_stream_id = 0
+        to_device_stream_id = self.store.get_to_device_stream_token()
+        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+            destination, last_device_stream_id, to_device_stream_id
+        )
+        edus = [
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.direct_to_device",
+                content=content,
+            )
+            for content in contents
+        ]
+        defer.returnValue((edus, stream_id))
+
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
@@ -211,13 +229,19 @@ class TransactionQueue(object):
                     self.store,
                 )
 
+                device_message_edus, device_stream_id = (
+                    yield self._get_new_device_messages(destination)
+                )
+
+                edus.extend(device_message_edus)
+
                 logger.debug(
                     "TX [%s] {%s} Attempting new transaction"
                     " (pdus: %d, edus: %d, failures: %d)",
                     destination, txn_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures)
+                    len(pdus),
+                    len(edus),
+                    len(failures)
                 )
 
                 logger.debug("TX [%s] Persisting transaction...", destination)
@@ -242,9 +266,9 @@ class TransactionQueue(object):
                     " (PDUs: %d, EDUs: %d, failures: %d)",
                     destination, txn_id,
                     transaction.transaction_id,
-                    len(pending_pdus),
-                    len(pending_edus),
-                    len(pending_failures),
+                    len(pdus),
+                    len(edus),
+                    len(failures),
                 )
 
                 with limiter:
@@ -299,6 +323,11 @@ class TransactionQueue(object):
                         logger.info(
                             "Failed to send event %s to %s", p.event_id, destination
                         )
+                else:
+                    # Remove the acknowledged device messages from the database
+                    yield self.store.delete_device_msgs_for_remote(
+                        destination, device_stream_id
+                    )
             except NotRetryingDestination:
                 logger.info(
                     "TX [%s] not ready for retry yet - "
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
new file mode 100644
index 0000000000..7e59c0d487
--- /dev/null
+++ b/synapse/handlers/devicemessage.py
@@ -0,0 +1,121 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.types import get_domain_from_id
+from synapse.util.stringutils import random_string
+
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceMessageHandler(object):
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+        self.is_mine_id = hs.is_mine_id
+        self.federation = hs.get_replication_layer()
+
+        self.federation.register_edu_handler(
+            "m.direct_to_device", self.on_direct_to_device_edu
+        )
+
+    @defer.inlineCallbacks
+    def on_direct_to_device_edu(self, origin, content):
+        local_messages = {}
+        sender_user_id = content["sender"]
+        if origin != get_domain_from_id(sender_user_id):
+            logger.warn(
+                "Dropping device message from %r with spoofed sender %r",
+                origin, sender_user_id
+            )
+        message_type = content["type"]
+        message_id = content["message_id"]
+        for user_id, by_device in content["messages"].items():
+            messages_by_device = {
+                device_id: {
+                    "content": message_content,
+                    "type": message_type,
+                    "sender": sender_user_id,
+                }
+                for device_id, message_content in by_device.items()
+            }
+            if messages_by_device:
+                local_messages[user_id] = messages_by_device
+
+        stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+            origin, message_id, local_messages
+        )
+
+        self.notifier.on_new_event(
+            "to_device_key", stream_id, users=local_messages.keys()
+        )
+
+    @defer.inlineCallbacks
+    def send_device_message(self, sender_user_id, message_type, messages):
+
+        local_messages = {}
+        remote_messages = {}
+        for user_id, by_device in messages.items():
+            if self.is_mine_id(user_id):
+                messages_by_device = {
+                    device_id: {
+                        "content": message_content,
+                        "type": message_type,
+                        "sender": sender_user_id,
+                    }
+                    for device_id, message_content in by_device.items()
+                }
+                if messages_by_device:
+                    local_messages[user_id] = messages_by_device
+            else:
+                destination = get_domain_from_id(user_id)
+                remote_messages.setdefault(destination, {})[user_id] = by_device
+
+        message_id = random_string(16)
+
+        remote_edu_contents = {}
+        for destination, messages in remote_messages.items():
+            remote_edu_contents[destination] = {
+                "messages": messages,
+                "sender": sender_user_id,
+                "type": message_type,
+                "message_id": message_id,
+            }
+
+        stream_id = yield self.store.add_messages_to_device_inbox(
+            local_messages, remote_edu_contents
+        )
+
+        self.notifier.on_new_event(
+            "to_device_key", stream_id, users=local_messages.keys()
+        )
+
+        for destination in remote_messages.keys():
+            # Hack to send make synapse send a federation transaction
+            # to the remote servers.
+            self.federation.send_edu(
+                destination=destination,
+                edu_type="m.ping",
+                content={},
+            )
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 9c10a99acf..5975164b37 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -16,10 +16,11 @@
 import logging
 
 from twisted.internet import defer
-from synapse.http.servlet import parse_json_object_from_request
 
 from synapse.http import servlet
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.rest.client.v1.transactions import HttpTransactionStore
+
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
@@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         super(SendToDeviceRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self.is_mine_id = hs.is_mine_id
         self.txns = HttpTransactionStore()
+        self.device_message_handler = hs.get_device_message_handler()
 
     @defer.inlineCallbacks
     def on_PUT(self, request, message_type, txn_id):
@@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        # TODO: Prod the notifier to wake up sync streams.
-        # TODO: Implement replication for the messages.
-        # TODO: Send the messages to remote servers if needed.
-
-        local_messages = {}
-        for user_id, by_device in content["messages"].items():
-            if self.is_mine_id(user_id):
-                messages_by_device = {
-                    device_id: {
-                        "content": message_content,
-                        "type": message_type,
-                        "sender": requester.user.to_string(),
-                    }
-                    for device_id, message_content in by_device.items()
-                }
-                if messages_by_device:
-                    local_messages[user_id] = messages_by_device
-
-        stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
-
-        self.notifier.on_new_event(
-            "to_device_key", stream_id, users=local_messages.keys()
+        sender_user_id = requester.user.to_string()
+
+        yield self.device_message_handler.send_device_message(
+            sender_user_id, message_type, content["messages"]
         )
 
         response = (200, {})
diff --git a/synapse/server.py b/synapse/server.py
index af3246504b..f516f08167 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -35,6 +35,7 @@ from synapse.federation import initialize_http_replication
 from synapse.handlers import Handlers
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.presence import PresenceHandler
@@ -100,6 +101,7 @@ class HomeServer(object):
         'application_service_api',
         'application_service_scheduler',
         'application_service_handler',
+        'device_message_handler',
         'notifier',
         'distributor',
         'client_resource',
@@ -205,6 +207,9 @@ class HomeServer(object):
     def build_device_handler(self):
         return DeviceHandler(self)
 
+    def build_device_message_handler(self):
+        return DeviceMessageHandler(self)
+
     def build_e2e_keys_handler(self):
         return E2eKeysHandler(self)
 
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 988577a334..d9f91ccc4e 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -59,10 +59,10 @@ class DeviceInboxStore(SQLBaseStore):
             self._add_messages_to_local_device_inbox_txn(
                 txn, stream_id, local_messages_by_user_then_device
             )
-            add_messages_to_device_federation_outbox(now_ms, stream_id)
+            add_messages_to_device_federation_outbox(txn, now_ms, stream_id)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
-            now_ms = self.clock.time_now_ms()
+            now_ms = self.clock.time_msec()
             yield self.runInteraction(
                 "add_messages_to_device_inbox",
                 add_messages_txn,
@@ -100,7 +100,7 @@ class DeviceInboxStore(SQLBaseStore):
             )
 
         with self._device_inbox_id_gen.get_next() as stream_id:
-            now_ms = self.clock.time_now_ms()
+            now_ms = self.clock.time_msec()
             yield self.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
@@ -239,8 +239,7 @@ class DeviceInboxStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def get_new_device_messages_for_remote_destination(
+    def get_new_device_msgs_for_remote(
         self, destination, last_stream_id, current_stream_id, limit=100
     ):
         """
@@ -274,13 +273,11 @@ class DeviceInboxStore(SQLBaseStore):
             return (messages, stream_pos)
 
         return self.runInteraction(
-            "get_new_device_messages_for_remote_destination",
+            "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )
 
-    @defer.inlineCallbacks
-    def delete_device_messages_for_remote_destination(self, destination,
-                                                      up_to_stream_id):
+    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
         """Used to delete messages when the remote destination acknowledges
         their receipt.
 
@@ -293,12 +290,12 @@ class DeviceInboxStore(SQLBaseStore):
         def delete_messages_for_remote_destination_txn(txn):
             sql = (
                 "DELETE FROM device_federation_outbox"
-                " WHERE destination = ? AND"
+                " WHERE destination = ?"
                 " AND stream_id <= ?"
             )
             txn.execute(sql, (destination, up_to_stream_id))
 
         return self.runInteraction(
-            "delete_device_messages_for_remote_destination",
+            "delete_device_msgs_for_remote",
             delete_messages_for_remote_destination_txn
         )
diff --git a/synapse/storage/schema/delta/34/device_outbox.sql b/synapse/storage/schema/delta/34/device_outbox.sql
index a319f73e47..e87066d9a1 100644
--- a/synapse/storage/schema/delta/34/device_outbox.sql
+++ b/synapse/storage/schema/delta/34/device_outbox.sql
@@ -16,9 +16,7 @@
 CREATE TABLE device_federation_outbox (
     destination TEXT NOT NULL,
     stream_id BIGINT NOT NULL,
-    sender TEXT NOT NULL,
-    message_id TEXT NOT NULL,
-    sent_ts BIGINT NOT NULL,
+    queued_ts BIGINT NOT NULL,
     messages_json TEXT NOT NULL
 );