summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/send_queue.py38
-rw-r--r--synapse/federation/transaction_queue.py32
2 files changed, 63 insertions, 7 deletions
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index d439be050a..3fc625c4dd 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -23,11 +23,13 @@ PRESENCE_TYPE = "p"
 KEYED_EDU_TYPE = "k"
 EDU_TYPE = "e"
 FAILURE_TYPE = "f"
+DEVICE_MESSAGE_TYPE = "d"
 
 
 class FederationRemoteSendQueue(object):
 
     def __init__(self, hs):
+        self.server_name = hs.hostname
         self.clock = hs.get_clock()
 
         # TODO: Add metrics for size of lists below
@@ -45,6 +47,8 @@ class FederationRemoteSendQueue(object):
         self.pos = 1
         self.pos_time = sorteddict()
 
+        self.device_messages = sorteddict()
+
         self.clock.looping_call(self._clear_queue, 30 * 1000)
 
     def _next_pos(self):
@@ -111,6 +115,15 @@ class FederationRemoteSendQueue(object):
         for key in keys[:i]:
             del self.failures[key]
 
+        # Delete things out of device map
+        keys = self.device_messages.keys()
+        i = keys.bisect_left(position_to_delete)
+        for key in keys[:i]:
+            del self.device_messages[key]
+
+    def notify_new_events(self, current_id):
+        pass
+
     def send_edu(self, destination, edu_type, content, key=None):
         pos = self._next_pos()
 
@@ -122,6 +135,7 @@ class FederationRemoteSendQueue(object):
         )
 
         if key:
+            assert isinstance(key, tuple)
             self.keyed_edu[(destination, key)] = edu
             self.keyed_edu_changed[pos] = (destination, key)
         else:
@@ -148,9 +162,9 @@ class FederationRemoteSendQueue(object):
         # This gets sent down a separate path
         pass
 
-    def notify_new_device_message(self, destination):
-        # TODO
-        pass
+    def send_device_messages(self, destination):
+        pos = self._next_pos()
+        self.device_messages[pos] = destination
 
     def get_current_token(self):
         return self.pos - 1
@@ -188,11 +202,11 @@ class FederationRemoteSendQueue(object):
         i = keys.bisect_right(token)
         keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
 
-        for (pos, edu_key) in keyed_edus:
+        for (pos, (destination, edu_key)) in keyed_edus:
             rows.append(
                 (pos, KEYED_EDU_TYPE, ujson.dumps({
                     "key": edu_key,
-                    "edu": self.keyed_edu[edu_key].get_dict(),
+                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
                 }))
             )
 
@@ -202,7 +216,7 @@ class FederationRemoteSendQueue(object):
         edus = set((k, self.edus[k]) for k in keys[i:])
 
         for (pos, edu) in edus:
-            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_dict())))
+            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
 
         # Fetch changed failures
         keys = self.failures.keys()
@@ -210,11 +224,21 @@ class FederationRemoteSendQueue(object):
         failures = set((k, self.failures[k]) for k in keys[i:])
 
         for (pos, (destination, failure)) in failures:
-            rows.append((pos, None, FAILURE_TYPE, ujson.dumps({
+            rows.append((pos, FAILURE_TYPE, ujson.dumps({
                 "destination": destination,
                 "failure": failure,
             })))
 
+        # Fetch changed device messages
+        keys = self.device_messages.keys()
+        i = keys.bisect_right(token)
+        device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+
+        for (pos, destination) in device_messages:
+            rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
+                "destination": destination,
+            })))
+
         # Sort rows based on pos
         rows.sort()
 
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5d4f244377..aa664beead 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,6 +26,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
@@ -54,6 +55,7 @@ class TransactionQueue(object):
         self.server_name = hs.hostname
 
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
         self.transaction_actions = TransactionActions(self.store)
 
         self.transport_layer = hs.get_federation_transport_client()
@@ -103,6 +105,9 @@ class TransactionQueue(object):
 
         self._order = 1
 
+        self._is_processing = False
+        self._last_token = 0
+
     def can_send_to(self, destination):
         """Can we send messages to the given server?
 
@@ -123,6 +128,33 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
+    @defer.inlineCallbacks
+    def notify_new_events(self, current_id):
+        if self._is_processing:
+            return
+
+        try:
+            self._is_processing = True
+            while True:
+                self._last_token, events = yield self.store.get_all_new_events_stream(
+                    self._last_token, current_id, limit=20,
+                )
+
+                if not events:
+                    break
+
+                for event in events:
+                    users_in_room = yield self.state.get_current_user_in_room(
+                        event.room_id, latest_event_ids=[event.event_id],
+                    )
+
+                    destinations = [
+                        get_domain_from_id(user_id) for user_id in users_in_room
+                    ]
+                    self.send_pdu(event, destinations)
+        finally:
+            self._is_processing = False
+
     def send_pdu(self, pdu, destinations):
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus