summary refs log tree commit diff
diff options
context:
space:
mode:
author Erik Johnston <erik@matrix.org> 2016-11-17 15:46:44 +0000
committer Erik Johnston <erik@matrix.org> 2016-11-17 15:48:04 +0000
commit f8ee66250a16cb9dd3af01fb1150ff18cfebbc39 (patch)
tree 9920bd4e8164f705b4e27c714d6c053082dcf7a5
parent Hook up the send queue and create a federation sender worker (diff)
download synapse-f8ee66250a16cb9dd3af01fb1150ff18cfebbc39.tar.xz
Handle sending events and device messages over federation
-rw-r--r-- synapse/app/federation_sender.py 31
-rw-r--r-- synapse/federation/send_queue.py 38
-rw-r--r-- synapse/federation/transaction_queue.py 32
-rw-r--r-- synapse/handlers/message.py 13
-rw-r--r-- synapse/notifier.py 2
-rw-r--r-- synapse/replication/resource.py 2
-rw-r--r-- synapse/replication/slave/storage/deviceinbox.py 15
-rw-r--r-- synapse/replication/slave/storage/events.py 11
-rw-r--r-- synapse/replication/slave/storage/transactions.py 4
-rw-r--r-- synapse/storage/deviceinbox.py 26
-rw-r--r-- synapse/storage/prepare_database.py 2
-rw-r--r-- synapse/storage/schema/delta/39/device_federation_stream_idx.sql 16
-rw-r--r-- synapse/storage/stream.py 31
-rw-r--r-- synapse/util/jsonobject.py 17
14 files changed, 185 insertions, 55 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 7a4fec4a66..32113c175c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -127,13 +127,6 @@ class FederationSenderServer(HomeServer):
         replication_url = self.config.worker_replication_url
         send_handler = self._get_send_handler()
 
-        def replicate(results):
-            stream = results.get("events")
-            if stream:
-                # max_stream_id = stream["position"]
-                # TODO
-                pass
-
         while True:
             try:
                 args = store.stream_positions()
@@ -142,7 +135,6 @@ class FederationSenderServer(HomeServer):
                 result = yield http_client.get_json(replication_url, args=args)
                 yield store.process_replication(result)
                 send_handler.process_replication(result)
-                replicate(result)
             except:
                 logger.exception("Error replicating from %r", replication_url)
                 yield sleep(30)
@@ -242,16 +234,17 @@ class FederationSenderHandler(object):
         return {"federation": self._latest_room_serial}
 
     def process_replication(self, result):
-        stream = result.get("federation")
-        if stream:
-            self._latest_room_serial = int(stream["position"])
+        fed_stream = result.get("federation")
+        if fed_stream:
+            self._latest_room_serial = int(fed_stream["position"])
 
             presence_to_send = {}
             keyed_edus = {}
             edus = {}
             failures = {}
+            device_destinations = set()
 
-            for row in stream["rows"]:
+            for row in fed_stream["rows"]:
                 position, typ, content_js = row
                 content = json.loads(content_js)
 
@@ -264,7 +257,9 @@ class FederationSenderHandler(object):
                     key = content["key"]
                     edu = Edu(**content["edu"])
 
-                    keyed_edus.setdefault(edu.destination, {})[key] = edu
+                    keyed_edus.setdefault(
+                        edu.destination, {}
+                    )[(edu.destination, tuple(key))] = edu
                 elif typ == send_queue.EDU_TYPE:
                     edu = Edu(**content)
 
@@ -274,6 +269,8 @@ class FederationSenderHandler(object):
                     failure = content["failure"]
 
                     failures.setdefault(destination, []).append(failure)
+                elif typ == send_queue.DEVICE_MESSAGE_TYPE:
+                    device_destinations.add(content["destination"])
                 else:
                     raise Exception("Unrecognised federation type: %r", typ)
 
@@ -296,6 +293,14 @@ class FederationSenderHandler(object):
                 for failure in failure_list:
                     self.federation_sender.send_failure(destination, failure)
 
+            for destination in device_destinations:
+                self.federation_sender.send_device_messages(destination)
+
+        event_stream = result.get("events")
+        if event_stream:
+            latest_pos = event_stream["position"]
+            self.federation_sender.notify_new_events(latest_pos)
+
 
 if __name__ == '__main__':
     with LoggingContext("main"):
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index d439be050a..3fc625c4dd 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -23,11 +23,13 @@ PRESENCE_TYPE = "p"
 KEYED_EDU_TYPE = "k"
 EDU_TYPE = "e"
 FAILURE_TYPE = "f"
+DEVICE_MESSAGE_TYPE = "d"
 
 
 class FederationRemoteSendQueue(object):
 
     def __init__(self, hs):
+        self.server_name = hs.hostname
         self.clock = hs.get_clock()
 
         # TODO: Add metrics for size of lists below
@@ -45,6 +47,8 @@ class FederationRemoteSendQueue(object):
         self.pos = 1
         self.pos_time = sorteddict()
 
+        self.device_messages = sorteddict()
+
         self.clock.looping_call(self._clear_queue, 30 * 1000)
 
     def _next_pos(self):
@@ -111,6 +115,15 @@ class FederationRemoteSendQueue(object):
         for key in keys[:i]:
             del self.failures[key]
 
+        # Delete things out of device map
+        keys = self.device_messages.keys()
+        i = keys.bisect_left(position_to_delete)
+        for key in keys[:i]:
+            del self.device_messages[key]
+
+    def notify_new_events(self, current_id):
+        pass
+
     def send_edu(self, destination, edu_type, content, key=None):
         pos = self._next_pos()
 
@@ -122,6 +135,7 @@ class FederationRemoteSendQueue(object):
         )
 
         if key:
+            assert isinstance(key, tuple)
             self.keyed_edu[(destination, key)] = edu
             self.keyed_edu_changed[pos] = (destination, key)
         else:
@@ -148,9 +162,9 @@ class FederationRemoteSendQueue(object):
         # This gets sent down a separate path
         pass
 
-    def notify_new_device_message(self, destination):
-        # TODO
-        pass
+    def send_device_messages(self, destination):
+        pos = self._next_pos()
+        self.device_messages[pos] = destination
 
     def get_current_token(self):
         return self.pos - 1
@@ -188,11 +202,11 @@ class FederationRemoteSendQueue(object):
         i = keys.bisect_right(token)
         keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
 
-        for (pos, edu_key) in keyed_edus:
+        for (pos, (destination, edu_key)) in keyed_edus:
             rows.append(
                 (pos, KEYED_EDU_TYPE, ujson.dumps({
                     "key": edu_key,
-                    "edu": self.keyed_edu[edu_key].get_dict(),
+                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
                 }))
             )
 
@@ -202,7 +216,7 @@ class FederationRemoteSendQueue(object):
         edus = set((k, self.edus[k]) for k in keys[i:])
 
         for (pos, edu) in edus:
-            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_dict())))
+            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
 
         # Fetch changed failures
         keys = self.failures.keys()
@@ -210,11 +224,21 @@ class FederationRemoteSendQueue(object):
         failures = set((k, self.failures[k]) for k in keys[i:])
 
         for (pos, (destination, failure)) in failures:
-            rows.append((pos, None, FAILURE_TYPE, ujson.dumps({
+            rows.append((pos, FAILURE_TYPE, ujson.dumps({
                 "destination": destination,
                 "failure": failure,
             })))
 
+        # Fetch changed device messages
+        keys = self.device_messages.keys()
+        i = keys.bisect_right(token)
+        device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+
+        for (pos, destination) in device_messages:
+            rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
+                "destination": destination,
+            })))
+
         # Sort rows based on pos
         rows.sort()
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5d4f244377..aa664beead 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
@@ -54,6 +55,7 @@ class TransactionQueue(object):
         self.server_name = hs.hostname
 
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
         self.transaction_actions = TransactionActions(self.store)
 
         self.transport_layer = hs.get_federation_transport_client()
@@ -103,6 +105,9 @@ class TransactionQueue(object):
 
         self._order = 1
 
+        self._is_processing = False
+        self._last_token = 0
+
     def can_send_to(self, destination):
         """Can we send messages to the given server?
 
@@ -123,6 +128,33 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
+    @defer.inlineCallbacks
+    def notify_new_events(self, current_id):
+        if self._is_processing:
+            return
+
+        try:
+            self._is_processing = True
+            while True:
+                self._last_token, events = yield self.store.get_all_new_events_stream(
+                    self._last_token, current_id, limit=20,
+                )
+
+                if not events:
+                    break
+
+                for event in events:
+                    users_in_room = yield self.state.get_current_user_in_room(
+                        event.room_id, latest_event_ids=[event.event_id],
+                    )
+
+                    destinations = [
+                        get_domain_from_id(user_id) for user_id in users_in_room
+                    ]
+                    self.send_pdu(event, destinations)
+        finally:
+            self._is_processing = False
+
     def send_pdu(self, pdu, destinations):
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 81df45177a..fd09397226 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.push.action_generator import ActionGenerator
 from synapse.types import (
-    UserID, RoomAlias, RoomStreamToken, get_domain_from_id
+    UserID, RoomAlias, RoomStreamToken,
 )
 from synapse.util.async import run_on_reactor, ReadWriteLock
 from synapse.util.logcontext import preserve_fn
@@ -599,13 +599,6 @@ class MessageHandler(BaseHandler):
             event_stream_id, max_stream_id
         )
 
-        users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
-        destinations = [
-            get_domain_from_id(user_id) for user_id in users_in_room
-            if not self.hs.is_mine_id(user_id)
-        ]
-
         @defer.inlineCallbacks
         def _notify():
             yield run_on_reactor()
@@ -618,7 +611,3 @@ class MessageHandler(BaseHandler):
 
         # If invite, remove room_state from unsigned before sending.
         event.unsigned.pop("invite_room_state", None)
-
-        preserve_fn(federation_handler.handle_new_event)(
-            event, destinations=destinations,
-        )
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 48653ae843..d528d1c1e0 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -143,6 +143,7 @@ class Notifier(object):
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
+        self.federation_sender = hs.get_federation_sender()
         self.state_handler = hs.get_state_handler()
 
         self.clock.looping_call(
@@ -219,6 +220,7 @@ class Notifier(object):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
         self.appservice_handler.notify_interested_services(room_stream_id)
+        self.federation_sender.notify_new_events(room_stream_id)
 
         if event.type == EventTypes.Member and event.membership == Membership.JOIN:
             self._user_joined_room(event.state_key, event.room_id)
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index a77312ae34..e708811326 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -453,7 +453,7 @@ class ReplicationResource(Resource):
             )
             upto_token = _position_from_rows(to_device_rows, current_position)
             writer.write_header_and_rows("to_device", to_device_rows, (
-                "position", "user_id", "device_id", "message_json"
+                "position", "entity",
             ), position=upto_token)
 
     @defer.inlineCallbacks
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 373212d42d..cc860f9f9b 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -38,6 +38,7 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
     get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
     get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
     delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+    delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
 
     def stream_positions(self):
         result = super(SlavedDeviceInboxStore, self).stream_positions()
@@ -50,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             self._device_inbox_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
                 stream_id = row[0]
-                user_id = row[1]
-                self._device_inbox_stream_cache.entity_has_changed(
-                    user_id, stream_id
-                )
+                entity = row[1]
+
+                if entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        entity, stream_id
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        entity, stream_id
+                    )
 
         return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 0c26e96e98..ef8713b55d 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 import ujson as json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
@@ -180,6 +185,8 @@ class SlavedEventStore(BaseSlavedStore):
         EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
     )
 
+    get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
+
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
         result["events"] = self._stream_id_gen.get_current_token()
@@ -194,6 +201,10 @@ class SlavedEventStore(BaseSlavedStore):
         stream = result.get("events")
         if stream:
             self._stream_id_gen.advance(int(stream["position"]))
+
+            if stream["rows"]:
+                logger.info("Got %d event rows", len(stream["rows"]))
+
             for row in stream["rows"]:
                 self._process_replication_row(
                     row, backfilled=False, state_resets=state_resets
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index c459301b76..d92cea4ab1 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -25,8 +25,8 @@ class TransactionStore(BaseSlavedStore):
     ].orig
     _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
 
-    def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
-        return []
+    prep_send_transaction = DataStore.prep_send_transaction.__func__
+    delivered_txn = DataStore.delivered_txn.__func__
 
     # For now, don't record the destination rety timings
     def set_destination_retry_timings(*args, **kwargs):
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index f640e73714..87398d60bc 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -269,27 +269,29 @@ class DeviceInboxStore(SQLBaseStore):
             return defer.succeed([])
 
         def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
             sql = (
-                "SELECT stream_id FROM device_inbox"
+                "SELECT stream_id, user_id"
+                " FROM device_inbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY stream_id"
                 " ORDER BY stream_id ASC"
-                " LIMIT ?"
             )
-            txn.execute(sql, (last_pos, current_pos, limit))
-            stream_ids = txn.fetchall()
-            if not stream_ids:
-                return []
-            max_stream_id_in_limit = stream_ids[-1]
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
 
             sql = (
-                "SELECT stream_id, user_id, device_id, message_json"
-                " FROM device_inbox"
+                "SELECT stream_id, destination"
+                " FROM device_federation_outbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC"
             )
-            txn.execute(sql, (last_pos, max_stream_id_in_limit))
-            return txn.fetchall()
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn.fetchall())
+
+            return rows
 
         return self.runInteraction(
             "get_all_new_device_messages", get_all_new_device_messages_txn
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6576a30098..e46ae6502e 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 38
+SCHEMA_VERSION = 39
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
new file mode 100644
index 0000000000..00be801e90
--- /dev/null
+++ b/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_federation_outbox_id ON device_federation_outbox(stream_id);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 888b1cb35d..f34cb78f9a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -765,3 +765,34 @@ class StreamStore(SQLBaseStore):
                 "token": end_token,
             },
         }
+
+    @defer.inlineCallbacks
+    def get_all_new_events_stream(self, from_id, current_id, limit):
+        """Get all new events"""
+
+        def get_all_new_events_stream_txn(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id"
+                " FROM events AS e"
+                " WHERE"
+                " ? < e.stream_ordering AND e.stream_ordering <= ?"
+                " ORDER BY e.stream_ordering ASC"
+                " LIMIT ?"
+            )
+
+            txn.execute(sql, (from_id, current_id, limit))
+            rows = txn.fetchall()
+
+            upper_bound = current_id
+            if len(rows) == limit:
+                upper_bound = rows[-1][0]
+
+            return upper_bound, [row[1] for row in rows]
+
+        upper_bound, event_ids = yield self.runInteraction(
+            "get_all_new_events_stream", get_all_new_events_stream_txn,
+        )
+
+        events = yield self._get_events(event_ids)
+
+        defer.returnValue((upper_bound, events))
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index 3fd5c3d9fd..d668e5a6b8 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -76,15 +76,26 @@ class JsonEncodedObject(object):
         d.update(self.unrecognized_keys)
         return d
 
+    def get_internal_dict(self):
+        d = {
+            k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+            if k in self.valid_keys
+        }
+        d.update(self.unrecognized_keys)
+        return d
+
     def __str__(self):
         return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
 
 
-def _encode(obj):
+def _encode(obj, internal=False):
     if type(obj) is list:
-        return [_encode(o) for o in obj]
+        return [_encode(o, internal=internal) for o in obj]
 
     if isinstance(obj, JsonEncodedObject):
-        return obj.get_dict()
+        if internal:
+            return obj.get_internal_dict()
+        else:
+            return obj.get_dict()
 
     return obj