diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/storage/deviceinbox.py | 54 |
1 files changed, 36 insertions, 18 deletions
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 0d37bb961b..658fbef27b 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -130,19 +130,41 @@ class DeviceInboxStore(SQLBaseStore): def _add_messages_to_local_device_inbox_txn(self, txn, stream_id, messages_by_user_then_device): - local_users_and_devices = set() + local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): + messages_json_for_user = {} devices = messages_by_device.keys() - sql = ( - "SELECT user_id, device_id FROM devices" - " WHERE user_id = ? AND device_id IN (" - + ",".join("?" * len(devices)) - + ")" - ) - # TODO: Maybe this needs to be done in batches if there are - # too many local devices for a given user. - txn.execute(sql, [user_id] + devices) - local_users_and_devices.update(map(tuple, txn.fetchall())) + if len(devices) == 1 and devices[0] == "*": + # Handle wildcard device_ids. + sql = ( + "SELECT device_id FROM devices" + " WHERE user_id = ?" + ) + txn.execute(sql, (user_id,)) + message_json = ujson.dumps(messages_by_device["*"]) + for row in txn.fetchall(): + # Add the message for all devices for this user on this + # server. + device = row[0] + messages_json_for_user[device] = message_json + else: + sql = ( + "SELECT device_id FROM devices" + " WHERE user_id = ? AND device_id IN (" + + ",".join("?" * len(devices)) + + ")" + ) + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + txn.execute(sql, [user_id] + devices) + for row in txn.fetchall(): + # Only insert into the local inbox if the device exists on + # this server + device = row[0] + message_json = ujson.dumps(messages_by_device[device]) + messages_json_for_user[device] = message_json + + local_by_user_then_device[user_id] = messages_json_for_user sql = ( "INSERT INTO device_inbox" @@ -150,13 +172,9 @@ class DeviceInboxStore(SQLBaseStore): " VALUES (?,?,?,?)" ) rows = [] - for user_id, messages_by_device in messages_by_user_then_device.items(): - for device_id, message in messages_by_device.items(): - message_json = ujson.dumps(message) - # Only insert into the local inbox if the device exists on - # this server - if (user_id, device_id) in local_users_and_devices: - rows.append((user_id, device_id, stream_id, message_json)) + for user_id, messages_by_device in local_by_user_then_device.items(): + for device_id, message_json in messages_by_device.items(): + rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) |