diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index ec06620efb..411e47d98d 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep
@@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
- SlavedRegistrationStore,
+ SlavedRegistrationStore, SlavedDeviceStore,
):
pass
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 4dfc2dc648..9d250502e0 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
@@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
+ SlavedDeviceStore,
RoomStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
@@ -380,6 +382,28 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms
)
+ @defer.inlineCallbacks
+ def notify_device_list_update(result):
+ stream = result.get("device_lists")
+ if not stream:
+ return
+
+ position_index = stream["field_names"].index("position")
+ user_index = stream["field_names"].index("user_id")
+
+ for row in stream["rows"]:
+ logger.info("Handling device list row: %r", row)
+ position = row[position_index]
+ user_id = row[user_index]
+
+ rooms = yield store.get_rooms_for_user(user_id)
+ room_ids = [r.room_id for r in rooms]
+
+ notifier.on_new_event(
+ "device_list_key", position, rooms=room_ids,
+ )
+
+ @defer.inlineCallbacks
def notify(result):
stream = result.get("events")
if stream:
@@ -417,6 +441,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
+ yield notify_device_list_update(result)
while True:
try:
@@ -427,7 +452,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result)
typing_handler.process_replication(result)
yield presence_handler.process_replication(result)
- notify(result)
+ yield notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ed077c9a76..6fefb85890 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -221,22 +221,6 @@ class DeviceHandler(BaseHandler):
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
- def get_device_list_changes(self, user_id, room_ids, from_key):
- """For a user and their joined rooms, calculate which device updates
- we need to return.
- """
- room_ids = frozenset(room_ids)
-
- user_ids_changed = set()
- changed = yield self.store.get_user_whose_devices_changed(from_key)
- for other_user_id in changed:
- other_rooms = yield self.store.get_rooms_for_user(other_user_id)
- if room_ids.intersection(e.room_id for e in other_rooms):
- user_ids_changed.add(other_user_id)
-
- defer.returnValue(user_ids_changed)
-
- @defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 06bf626367..9199f20817 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -144,7 +144,6 @@ class SyncHandler(object):
self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs)
self.state = hs.get_state_handler()
- self.device_handler = hs.get_device_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
@@ -546,15 +545,9 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
- if since_token and since_token.device_list_key:
- user_id = sync_config.user.to_string()
- rooms = yield self.store.get_rooms_for_user(user_id)
- joined_room_ids = set(r.room_id for r in rooms)
- device_lists = yield self.device_handler.get_device_list_changes(
- user_id, joined_room_ids, since_token.device_list_key
- )
- else:
- device_lists = []
+ device_lists = yield self._generate_sync_entry_for_device_list(
+ sync_result_builder
+ )
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
@@ -568,6 +561,28 @@ class SyncHandler(object):
))
@defer.inlineCallbacks
+ def _generate_sync_entry_for_device_list(self, sync_result_builder):
+ user_id = sync_result_builder.sync_config.user.to_string()
+ since_token = sync_result_builder.since_token
+
+ if since_token and since_token.device_list_key:
+ rooms = yield self.store.get_rooms_for_user(user_id)
+ room_ids = set(r.room_id for r in rooms)
+
+ user_ids_changed = set()
+ changed = yield self.store.get_user_whose_devices_changed(
+ since_token.device_list_key
+ )
+ for other_user_id in changed:
+ other_rooms = yield self.store.get_rooms_for_user(other_user_id)
+ if room_ids.intersection(e.room_id for e in other_rooms):
+ user_ids_changed.add(other_user_id)
+
+ defer.returnValue(user_ids_changed)
+ else:
+ defer.returnValue([])
+
+ @defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates
`sync_result_builder` with the result.
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 4616e9b34a..36548c5eda 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",),
("public_rooms",),
("federation",),
+ ("device_lists",),
)
@@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token()
+ device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key),
int(public_rooms_token),
int(federation_token),
+ int(device_list_token),
))
@request_handler()
@@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams)
+ yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams)
@@ -495,6 +499,20 @@ class ReplicationResource(Resource):
"position", "type", "content",
), position=upto_token)
+ @defer.inlineCallbacks
+ def device_lists(self, writer, current_token, limit, request_streams):
+ current_position = current_token.device_lists
+
+ device_lists = request_streams.get("device_lists")
+
+ if device_lists is not None and device_lists != current_position:
+ changes = yield self.store.get_users_and_hosts_device_list_changes(
+ device_lists,
+ )
+ writer.write_header_and_rows("device_lists", changes, (
+ "position", "user_id", "destination",
+ ), position=current_position)
+
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@@ -527,7 +545,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
- "federation",
+ "federation", "device_lists",
))):
__slots__ = []
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
new file mode 100644
index 0000000000..ca46aa17b6
--- /dev/null
+++ b/synapse/replication/slave/storage/devices.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedDeviceStore(BaseSlavedStore):
+ def __init__(self, db_conn, hs):
+ super(SlavedDeviceStore, self).__init__(db_conn, hs)
+
+ self.hs = hs
+
+ self._device_list_id_gen = SlavedIdTracker(
+ db_conn, "device_lists_stream", "stream_id",
+ )
+ device_list_max = self._device_list_id_gen.get_current_token()
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache", device_list_max,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache", device_list_max,
+ )
+
+ get_device_stream_token = DataStore.get_device_stream_token.__func__
+ get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
+ get_devices_by_remote = DataStore.get_devices_by_remote.__func__
+ _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
+ _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
+ mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
+ _mark_as_sent_devices_by_remote_txn = (
+ DataStore._mark_as_sent_devices_by_remote_txn.__func__
+ )
+
+ def stream_positions(self):
+ result = super(SlavedDeviceStore, self).stream_positions()
+ result["device_lists"] = self._device_list_id_gen.get_current_token()
+ return result
+
+ def process_replication(self, result):
+ stream = result.get("device_lists")
+ if stream:
+ self._device_list_id_gen.advance(int(stream["position"]))
+ for row in stream["rows"]:
+ stream_id = row[0]
+ user_id = row[1]
+ destination = row[2]
+
+ self._device_list_stream_cache.entity_has_changed(
+ user_id, stream_id
+ )
+
+ if destination:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ destination, stream_id
+ )
+
+ return super(SlavedDeviceStore, self).process_replication(result)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 00317b0c1f..2b2cebacfa 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -458,6 +458,21 @@ class DeviceStore(SQLBaseStore):
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row["user_id"] for row in rows))
+ def get_users_and_hosts_device_list_changes(self, from_key):
+ """Return a list of `(stream_id, user_id, destination)` which is the
+ combined list of changes to devices, and which destinations need to be
+ poked. `destination` may be None if no destinations need to be poked.
+ """
+ sql = """
+ SELECT stream_id, user_id, destination FROM device_lists_stream
+ LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ WHERE stream_id > ?
+ """
+ return self._execute(
+ "get_users_and_hosts_device_list", None,
+ sql, from_key,
+ )
+
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
|