diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c38666da18..6324df883b 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = {
"users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
+ "device_lists_changes_in_room": ["converted_to_destinations"],
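+    # (Columns listed here are stored as 0/1 integers in SQLite and must be
+    # coerced to proper booleans when porting to Postgres.)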
}
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 0f90302c95..b3a9e50752 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -680,6 +680,14 @@ class ServerConfig(Config):
config.get("use_account_validity_in_account_status") or False
)
+        # This is a temporary option that enables fully using the new
+        # `device_lists_changes_in_room` table without the backwards-compat
+        # code. This is primarily for testing. If enabled, the server should
+        # *not* be downgraded, as that may lead to missing device list
+        # updates.
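+        #
+        # An illustrative homeserver.yaml snippet enabling it (a hypothetical
+        # deployment example, not a recommendation):
+        #
+        #     use_new_device_lists_changes_in_room: true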
+ self.use_new_device_lists_changes_in_room = (
+ config.get("use_new_device_lists_changes_in_room") or False
+ )
+
self.rooms_to_exclude_from_sync: List[str] = (
config.get("exclude_rooms_from_sync") or []
)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..c710c02cf9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import (
JsonDict,
StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
+ # Whether `_handle_new_device_update_async` is currently processing.
+ self._handle_new_device_update_is_processing = False
+
+        # Whether a new device update may have arrived while the loop was
+        # already processing.
+ self._handle_new_device_update_new_data = False
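+
+        # Together these two flags implement a simple "wake up the worker"
+        # pattern: only one run of `_handle_new_device_update_async` is active
+        # at a time, and any update that arrives mid-run simply flags that
+        # another pass is needed.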
+
+        # On startup, check if there are any updates pending.
+ hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do, to allow safely downgrading Synapse.
+ self.use_new_device_lists_changes_in_room = (
+ hs.config.server.use_new_device_lists_changes_in_room
+ )
+
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
# No changes to notify about, so this is a no-op.
return
- users_who_share_room = await self.store.get_users_who_share_room_with_user(
- user_id
- )
+ room_ids = await self.store.get_rooms_for_user(user_id)
+
+ hosts: Optional[Set[str]] = None
+ if not self.use_new_device_lists_changes_in_room:
+ hosts = set()
- hosts: Set[str] = set()
- if self.hs.is_mine_id(user_id):
- hosts.update(get_domain_from_id(u) for u in users_who_share_room)
- hosts.discard(self.server_name)
+ if self.hs.is_mine_id(user_id):
+ for room_id in room_ids:
+ joined_users = await self.store.get_users_in_room(room_id)
+ hosts.update(get_domain_from_id(u) for u in joined_users)
- set_tag("target_hosts", hosts)
+ set_tag("target_hosts", hosts)
+
+ hosts.discard(self.server_name)
position = await self.store.add_device_change_to_streams(
- user_id, device_ids, list(hosts)
+ user_id,
+ device_ids,
+ hosts=hosts,
+ room_ids=room_ids,
)
if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
- users_to_notify = users_who_share_room.union({user_id})
+ self.notifier.on_new_event(
+ "device_list_key", position, users={user_id}, rooms=room_ids
+ )
- self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+ # We may need to do some processing asynchronously.
+ self._handle_new_device_update_async()
if hosts:
logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
+ @wrap_as_background_process("_handle_new_device_update_async")
+ async def _handle_new_device_update_async(self) -> None:
+ """Called when we have a new local device list update that we need to
+ send out over federation.
+
+ This happens in the background so as not to block the original request
+ that generated the device update.
+ """
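+        # Only one instance of this loop may run at a time: if another run is
+        # already in flight, just record that new data arrived and return,
+        # leaving the running loop to pick it up on its next pass.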
+ if self._handle_new_device_update_is_processing:
+ self._handle_new_device_update_new_data = True
+ return
+
+ self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed in the previous iteration (if any), and
+        # the set of hosts we've already poked about for that stream ID. This
+        # is so that we don't poke the same remote server about the same
+        # update repeatedly.
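+        #
+        # (A single device update is stored as one row per room, all sharing
+        # a stream ID, so successive rows fetched below can describe the same
+        # update and target overlapping sets of hosts.)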
+ current_stream_id = None
+ hosts_already_sent_to: Set[str] = set()
+
+ try:
+ while True:
+ self._handle_new_device_update_new_data = False
+                rows = await self.store.get_unconverted_outbound_room_pokes()
+ if not rows:
+ # If the DB returned nothing then there is nothing left to
+ # do, *unless* a new device list update happened during the
+ # DB query.
+ if self._handle_new_device_update_new_data:
+ continue
+ else:
+ return
+
+ for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+ joined_user_ids = await self.store.get_users_in_room(room_id)
+ hosts = {get_domain_from_id(u) for u in joined_user_ids}
+ hosts.discard(self.server_name)
+
+ # Check if we've already sent this update to some hosts
+ if current_stream_id == stream_id:
+ hosts -= hosts_already_sent_to
+
+ await self.store.add_device_list_outbound_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ stream_id=stream_id,
+ hosts=hosts,
+ context=opentracing_context,
+ )
+
+ # Notify replication that we've updated the device list stream.
+ self.notifier.notify_replication()
+
+ if hosts:
+ logger.info(
+ "Sending device list update notif for %r to: %r",
+ user_id,
+ hosts,
+ )
+ for host in hosts:
+ self.federation_sender.send_device_messages(
+ host, immediate=False
+ )
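+                        # (`immediate=False` means the federation sender may
+                        # batch this poke with other pending traffic instead
+                        # of waking the destination queue straight away.)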
+ log_kv(
+ {"message": "sent device update to host", "host": host}
+ )
+
+ if current_stream_id != stream_id:
+ # Clear the set of hosts we've already sent to as we're
+ # processing a new update.
+ hosts_already_sent_to.clear()
+
+ hosts_already_sent_to.update(hosts)
+ current_stream_id = stream_id
+
+ finally:
+ self._handle_new_device_update_is_processing = False
+
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0ffd34f1da..f040e33bfb 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -44,6 +44,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1ea0b2aa6f..cdbe3872fa 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f08f7834d3..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
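+            -- Ordering matters here: LIMIT must return the lowest pending
+            -- stream IDs, otherwise callers paginating by stream ID could
+            -- skip rows.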
LIMIT ?
"""
@@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
- self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+ self,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: Optional[Collection[str]],
+ room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
- hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts has *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in.
Returns:
The maximum stream ID of device list updates that were added to the database, or
@@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
- async with self._device_list_id_gen.get_next_mult(
- len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_change_to_stream",
- self._add_device_change_to_stream_txn,
+ context = get_active_span_text_map()
+
+ def add_device_changes_txn(
+ txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+ ):
+ self._add_device_change_to_stream_txn(
+ txn,
user_id,
device_ids,
- stream_ids,
+ stream_ids_for_device_change,
)
- if not hosts:
- return stream_ids[-1]
+ self._add_device_outbound_room_poke_txn(
+ txn,
+ user_id,
+ device_ids,
+ room_ids,
+ stream_ids_for_device_change,
+ context,
+ hosts_have_been_calculated=hosts is not None,
+ )
- context = get_active_span_text_map()
- async with self._device_list_id_gen.get_next_mult(
- len(hosts) * len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_outbound_poke_to_stream",
- self._add_device_outbound_poke_to_stream_txn,
+ # If the set of hosts to send to has not been calculated yet (and so
+ # `hosts` is None) or there are no `hosts` to send to, then skip
+ # trying to persist them to the DB.
+ if not hosts:
+ return
+
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
user_id,
device_ids,
hosts,
- stream_ids,
+ stream_ids_for_outbound_pokes,
context,
)
+ # `device_lists_stream` wants a stream ID per device update.
+ num_stream_ids = len(device_ids)
+
+ if hosts:
+ # `device_lists_outbound_pokes` wants a different stream ID for
+ # each row, which is a row per host per device update.
+ num_stream_ids += len(hosts) * len(device_ids)
+
+ async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+ stream_ids_for_device_change = stream_ids[: len(device_ids)]
+ stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
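+            # For example, 2 device IDs and 3 hosts allocate 2 + 3 * 2 = 8
+            # stream IDs: the first 2 feed `device_lists_stream` and the
+            # remaining 6 feed `device_lists_outbound_pokes` (one per host
+            # per device).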
+
+ await self.db_pool.runInteraction(
+ "add_device_change_to_stream",
+ add_device_changes_txn,
+ stream_ids_for_device_change,
+ stream_ids_for_outbound_pokes,
+ )
+
return stream_ids[-1]
def _add_device_change_to_stream_txn(
@@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
- next_stream_id = iter(stream_ids)
+ stream_id_iterator = iter(stream_ids)
+ encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
- next(next_stream_id),
+ next(stream_id_iterator),
user_id,
device_id,
False,
now,
- json_encoder.encode(context)
- if whitelisted_homeserver(destination)
- else "{}",
+ encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
+
+ def _add_device_outbound_room_poke_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Iterable[str],
+ room_ids: Collection[str],
+        stream_ids: List[int],
+ context: Dict[str, str],
+ hosts_have_been_calculated: bool,
+ ) -> None:
+        """Record that a user has updated their device in the given rooms.
+
+ Args:
+            hosts_have_been_calculated: True if the hosts to notify have
+                already been calculated and written to
+                `device_lists_outbound_pokes`.
+ """
+
+ # We only need to convert to outbound pokes if they are our user.
+ converted_to_destinations = (
+ hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+ )
+
+ encoded_context = json_encoder.encode(context)
+
+ # The `device_lists_changes_in_room.stream_id` column matches the
+ # corresponding `stream_id` of the update in the `device_lists_stream`
+ # table, i.e. all rows persisted for the same device update will have
+ # the same `stream_id` (but different room IDs).
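+        # For example, one update to a device of a user joined to rooms !a
+        # and !b produces two rows here that differ only in `room_id`.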
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keys=(
+ "user_id",
+ "device_id",
+ "room_id",
+ "stream_id",
+ "converted_to_destinations",
+ "opentracing_context",
+ ),
+ values=[
+ (
+ user_id,
+ device_id,
+ room_id,
+ stream_id,
+ converted_to_destinations,
+ encoded_context,
+ )
+ for room_id in room_ids
+ for device_id, stream_id in zip(device_ids, stream_ids)
+ ],
+ )
+
+    async def get_unconverted_outbound_room_pokes(
+ self, limit: int = 10
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+ """Get device list changes by room that have not yet been handled and
+ written to `device_lists_outbound_pokes`.
+
+ Returns:
+ A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ """
+
+ sql = """
+ SELECT user_id, device_id, room_id, stream_id, opentracing_context
+ FROM device_lists_changes_in_room
+ WHERE NOT converted_to_destinations
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+        def get_unconverted_outbound_room_pokes_txn(txn):
+ txn.execute(sql, (limit,))
+ return txn.fetchall()
+
+ return await self.db_pool.runInteraction(
+            "get_unconverted_outbound_room_pokes", get_unconverted_outbound_room_pokes_txn
+ )
+
+ async def add_device_list_outbound_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ stream_id: int,
+ hosts: Collection[str],
+ context: Optional[Dict[str, str]],
+ ) -> None:
+ """Queue the device update to be sent to the given set of hosts,
+ calculated from the room ID.
+
+ Marks the associated row in `device_lists_changes_in_room` as handled.
+ """
+
+ def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ if hosts:
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_ids=[device_id],
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
+ )
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
+
+ if not hosts:
+            # If there are no hosts then we don't try to generate stream IDs.
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ [],
+ )
+
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ stream_ids,
+ )
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index ea900e0f3d..151f2aa9bb 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -60,6 +60,7 @@ Changes in SCHEMA_VERSION = 68:
new events.
Changes in SCHEMA_VERSION = 69:
+  - We now write to the `device_lists_changes_in_room` table.
- Use sequence to generate future `application_services_txns.txn_id`s
"""
diff --git a/synapse/storage/schema/main/delta/69/01device_list_outbound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_outbound_by_room.sql
new file mode 100644
index 0000000000..b5b1782b2a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01device_list_outbound_by_room.sql
@@ -0,0 +1,38 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_lists_changes_in_room (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+
+ -- This initially matches `device_lists_stream.stream_id`. Note that we
+    -- delete older values from `device_lists_stream`, so we can't use a
+    -- foreign key constraint here.
+ --
+ -- The table will contain rows with the same `stream_id` but different
+ -- `room_id`, as for each device update we store a row per room the user is
+ -- joined to. Therefore `(stream_id, room_id)` gives a unique index.
+ stream_id BIGINT NOT NULL,
+
+ -- We have a background process which goes through this table and converts
+ -- entries into rows in `device_lists_outbound_pokes`. Once we have processed
+ -- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
+ converted_to_destinations BOOLEAN NOT NULL,
+ opentracing_context TEXT
+);
+
+CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
+CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;
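+
+-- The partial index above keeps the background conversion cheap: an
+-- illustrative form of the query run by `get_unconverted_outbound_room_pokes`
+-- is:
+--
+--   SELECT user_id, device_id, room_id, stream_id, opentracing_context
+--   FROM device_lists_changes_in_room
+--   WHERE NOT converted_to_destinations
+--   ORDER BY stream_id
+--   LIMIT ?;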
|