summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2016-09-08 15:13:05 +0100
committerMark Haines <mark.haines@matrix.org>2016-09-08 15:13:05 +0100
commita1c8f268e5948d6466d64ef983b98fce287ec907 (patch)
treecc4afdcd86aeac038179c8aae7cbe214f37d55b8 /synapse/storage
parentMerge pull request #1074 from matrix-org/markjh/direct_to_device_federation (diff)
downloadsynapse-a1c8f268e5948d6466d64ef983b98fce287ec907.tar.xz
Support wildcard device_ids for direct to device messages
Diffstat (limited to '')
-rw-r--r--synapse/storage/deviceinbox.py54
1 files changed, 36 insertions, 18 deletions
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 0d37bb961b..658fbef27b 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -130,19 +130,41 @@ class DeviceInboxStore(SQLBaseStore):
 
     def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
                                                 messages_by_user_then_device):
-        local_users_and_devices = set()
+        local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
+            messages_json_for_user = {}
             devices = messages_by_device.keys()
-            sql = (
-                "SELECT user_id, device_id FROM devices"
-                " WHERE user_id = ? AND device_id IN ("
-                + ",".join("?" * len(devices))
-                + ")"
-            )
-            # TODO: Maybe this needs to be done in batches if there are
-            # too many local devices for a given user.
-            txn.execute(sql, [user_id] + devices)
-            local_users_and_devices.update(map(tuple, txn.fetchall()))
+            if len(devices) == 1 and devices[0] == "*":
+                # Handle wildcard device_ids.
+                sql = (
+                    "SELECT device_id FROM devices"
+                    " WHERE user_id = ?"
+                )
+                txn.execute(sql, (user_id,))
+                message_json = ujson.dumps(messages_by_device["*"])
+                for row in txn.fetchall():
+                    # Add the message for all devices for this user on this
+                    # server.
+                    device = row[0]
+                    messages_json_for_user[device] = message_json
+            else:
+                sql = (
+                    "SELECT device_id FROM devices"
+                    " WHERE user_id = ? AND device_id IN ("
+                    + ",".join("?" * len(devices))
+                    + ")"
+                )
+                # TODO: Maybe this needs to be done in batches if there are
+                # too many local devices for a given user.
+                txn.execute(sql, [user_id] + devices)
+                for row in txn.fetchall():
+                    # Only insert into the local inbox if the device exists on
+                    # this server
+                    device = row[0]
+                    message_json = ujson.dumps(messages_by_device[device])
+                    messages_json_for_user[device] = message_json
+
+            local_by_user_then_device[user_id] = messages_json_for_user
 
         sql = (
             "INSERT INTO device_inbox"
@@ -150,13 +172,9 @@ class DeviceInboxStore(SQLBaseStore):
             " VALUES (?,?,?,?)"
         )
         rows = []
-        for user_id, messages_by_device in messages_by_user_then_device.items():
-            for device_id, message in messages_by_device.items():
-                message_json = ujson.dumps(message)
-                # Only insert into the local inbox if the device exists on
-                # this server
-                if (user_id, device_id) in local_users_and_devices:
-                    rows.append((user_id, device_id, stream_id, message_json))
+        for user_id, messages_by_device in local_by_user_then_device.items():
+            for device_id, message_json in messages_by_device.items():
+                rows.append((user_id, device_id, stream_id, message_json))
 
         txn.executemany(sql, rows)