diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index bf6b1c1535..8758af4ca1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -444,6 +444,16 @@ class RoomListHandler(BaseHandler):
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
+ def get_remote_public_room_list(self, server_name):
+ res = yield self.hs.get_replication_layer().get_public_rooms(
+ [server_name]
+ )
+
+ if server_name not in res:
+ raise SynapseError(404, "Server not found")
+ defer.returnValue(res[server_name])
+
+ @defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0d81757010..3c933f1620 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import parse_json_object_from_request, parse_string
import logging
import urllib
@@ -295,15 +295,26 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
+ server = parse_string(request, "server", default=None)
+
try:
yield self.auth.get_user_by_req(request)
- except AuthError:
- # This endpoint isn't authed, but its useful to know who's hitting
- # it if they *do* supply an access token
- pass
+ except AuthError as e:
+ # We allow people to not be authed if they're just looking at our
+ # room list, but require auth when we proxy the request.
+ # In both cases we call the auth function, as that has the side
+ # effect of logging who issued this request if an access token was
+ # provided.
+ if server:
+ raise e
+ else:
+ pass
handler = self.hs.get_room_list_handler()
- data = yield handler.get_aggregated_public_room_list()
+ if server:
+ data = yield handler.get_remote_public_room_list(server)
+ else:
+ data = yield handler.get_aggregated_public_room_list()
defer.returnValue((200, data))
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 0d37bb961b..658fbef27b 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -130,19 +130,41 @@ class DeviceInboxStore(SQLBaseStore):
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
messages_by_user_then_device):
- local_users_and_devices = set()
+ local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
+ messages_json_for_user = {}
devices = messages_by_device.keys()
- sql = (
- "SELECT user_id, device_id FROM devices"
- " WHERE user_id = ? AND device_id IN ("
- + ",".join("?" * len(devices))
- + ")"
- )
- # TODO: Maybe this needs to be done in batches if there are
- # too many local devices for a given user.
- txn.execute(sql, [user_id] + devices)
- local_users_and_devices.update(map(tuple, txn.fetchall()))
+ if len(devices) == 1 and devices[0] == "*":
+ # Handle wildcard device_ids.
+ sql = (
+ "SELECT device_id FROM devices"
+ " WHERE user_id = ?"
+ )
+ txn.execute(sql, (user_id,))
+ message_json = ujson.dumps(messages_by_device["*"])
+ for row in txn.fetchall():
+ # Add the message for all devices for this user on this
+ # server.
+ device = row[0]
+ messages_json_for_user[device] = message_json
+ else:
+ sql = (
+ "SELECT device_id FROM devices"
+ " WHERE user_id = ? AND device_id IN ("
+ + ",".join("?" * len(devices))
+ + ")"
+ )
+ # TODO: Maybe this needs to be done in batches if there are
+ # too many local devices for a given user.
+ txn.execute(sql, [user_id] + devices)
+ for row in txn.fetchall():
+ # Only insert into the local inbox if the device exists on
+ # this server
+ device = row[0]
+ message_json = ujson.dumps(messages_by_device[device])
+ messages_json_for_user[device] = message_json
+
+ local_by_user_then_device[user_id] = messages_json_for_user
sql = (
"INSERT INTO device_inbox"
@@ -150,13 +172,9 @@ class DeviceInboxStore(SQLBaseStore):
" VALUES (?,?,?,?)"
)
rows = []
- for user_id, messages_by_device in messages_by_user_then_device.items():
- for device_id, message in messages_by_device.items():
- message_json = ujson.dumps(message)
- # Only insert into the local inbox if the device exists on
- # this server
- if (user_id, device_id) in local_users_and_devices:
- rows.append((user_id, device_id, stream_id, message_json))
+ for user_id, messages_by_device in local_by_user_then_device.items():
+ for device_id, message_json in messages_by_device.items():
+ rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
|