diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index ca0fe8c4be..05a193f889 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import (
TYPE_CHECKING,
@@ -39,6 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -47,12 +47,19 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -69,7 +76,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -78,9 +85,23 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
):
super().__init__(database, db_conn, hs)
+ # In the worker store this is an ID tracker, which the non-worker class below
+ # (used on the main process) narrows to a writable ID generator.
+ self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
+ ],
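+ # Only the main process writes to the device list streams; workers merely
+ # track the current position.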
+ is_writer=hs.config.worker.worker_app is None,
+ )
+
- # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
- # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
- device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
+ device_list_max = self._device_list_id_gen.get_current_token()
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -134,6 +155,39 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
+ if stream_name == DeviceListsStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+ self._invalidate_caches_for_devices(token, rows)
+ elif stream_name == UserSignatureStream.NAME:
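+ # The user signature stream shares the device list ID generator (see the
+ # extra_tables passed to StreamIdGenerator in __init__), so advance it too.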
+ self._device_list_id_gen.advance(instance_name, token)
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def _invalidate_caches_for_devices(
+ self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+ ) -> None:
+ for row in rows:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
+
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
+
+ def get_device_stream_token(self) -> int:
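+ """Get the current stream id from the _device_list_id_gen"""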
+ return self._device_list_id_gen.get_current_token()
+
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -272,6 +326,13 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
destination, int(from_stream_id)
)
if not has_changed:
+ # debugging for https://github.com/matrix-org/synapse/issues/14251
+ issue_8631_logger.debug(
+ "%s: no change between %i and %i",
+ destination,
+ from_stream_id,
+ now_stream_id,
+ )
return now_stream_id, []
updates = await self.db_pool.runInteraction(
@@ -464,7 +525,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
limit: Maximum number of device updates to return
Returns:
- List: List of device update tuples:
+ List of device update tuples:
- user_id
- device_id
- stream_id
@@ -537,9 +598,11 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
- "org.matrix.opentracing_context": opentracing_context,
}
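+ # Only attach the opentracing context when it is non-empty.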
+ if opentracing_context != "{}":
+ result["org.matrix.opentracing_context"] = opentracing_context
+
prev_id = stream_id
if device is not None:
@@ -547,7 +610,11 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
if keys:
result["keys"] = keys
- device_display_name = device.display_name
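+ # Only expose the device's display name if the configuration allows
+ # device name lookups over federation.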
+ device_display_name = None
+ if (
+ self.hs.config.federation.allow_device_name_lookup_over_federation
+ ):
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -662,12 +729,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
},
)
- @abc.abstractmethod
- def get_device_stream_token(self) -> int:
- """Get the current stream id from the _device_list_id_gen"""
- ...
-
@trace
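+ # Marked cancellable: this is a read-only lookup, so it is safe to abandon
+ # part-way through.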
+ @cancellable
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
@@ -743,6 +806,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
+ @cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
@@ -982,24 +1046,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
desc="mark_remote_user_device_cache_as_valid",
)
+ async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ await self.db_pool.runInteraction(
+ "_handle_potentially_left_users",
+ self.handle_potentially_left_users_txn,
+ user_ids,
+ )
+
+ def handle_potentially_left_users_txn(
+ self,
+ txn: LoggingTransaction,
+ user_ids: Set[str],
+ ) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+ left_users = user_ids - joined_users
+
+ for user_id in left_users:
+ self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user."""
- def _mark_remote_user_device_list_as_unsubscribed_txn(
- txn: LoggingTransaction,
- ) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
- )
-
await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
- _mark_remote_user_device_list_as_unsubscribed_txn,
+ self.mark_remote_user_device_list_as_unsubscribed_txn,
+ user_id,
+ )
+
+ def mark_remote_user_device_list_as_unsubscribed_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ ) -> None:
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
async def get_dehydrated_device(
@@ -1221,6 +1320,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
desc="get_min_device_lists_changes_in_room",
)
+ @cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
) -> Optional[Set[str]]:
@@ -1267,6 +1367,33 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
return changes
+ async def get_device_list_changes_in_room(
+ self, room_id: str, min_stream_id: int
+ ) -> Collection[Tuple[str, str]]:
+ """Get all device list changes that happened in the room since the given
+ stream ID.
+
+ Returns:
+ Collection of user ID/device ID tuples of all devices that have
+ changed
+ """
+
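+ # DISTINCT because the same device may have changed multiple times in range.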
+ sql = """
+ SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
+ WHERE room_id = ? AND stream_id > ?
+ """
+
+ def get_device_list_changes_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Collection[Tuple[str, str]]:
+ txn.execute(sql, (room_id, min_stream_id))
+ return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction(
+ "get_device_list_changes_in_room",
+ get_device_list_changes_in_room_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -1314,6 +1441,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._remove_duplicate_outbound_pokes,
)
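+ # Index on (room_id, stream_id), used when fetching the device list
+ # changes for a given room.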
+ self.db_pool.updates.register_background_index_update(
+ "device_lists_changes_in_room_by_room_index",
+ index_name="device_lists_changes_in_room_by_room_idx",
+ table="device_lists_changes_in_room",
+ columns=["room_id", "stream_id"],
+ )
+
async def _drop_device_list_streams_non_unique_indexes(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1401,6 +1535,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ # Because we have write access, this will be a StreamIdGenerator
+ # (see DeviceWorkerStore.__init__)
+ _device_list_id_gen: AbstractStreamIdGenerator
+
def __init__(
self,
database: DatabasePool,
@@ -1725,7 +1863,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context,
)
- async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1775,7 +1913,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Iterable[str],
+ device_id: str,
hosts: Collection[str],
stream_ids: List[int],
context: Optional[Dict[str, str]],
@@ -1791,6 +1929,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
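+ # We only need to send out updates for *our* users; rows for remote users
+ # are marked as already sent.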
+ mark_sent = not self.hs.is_mine_id(user_id)
+
+ values = [
+ (
+ destination,
+ next(stream_id_iterator),
+ user_id,
+ device_id,
+ mark_sent,
+ now,
+ encoded_context if whitelisted_homeserver(destination) else "{}",
+ )
+ for destination in hosts
+ ]
+
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1803,23 +1956,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"ts",
"opentracing_context",
),
- values=[
- (
- destination,
- next(stream_id_iterator),
- user_id,
- device_id,
- not self.hs.is_mine_id(
- user_id
- ), # We only need to send out update for *our* users
- now,
- encoded_context if whitelisted_homeserver(destination) else "{}",
- )
- for destination in hosts
- for device_id in device_ids
- ],
+ values=values,
)
+ # debugging for https://github.com/matrix-org/synapse/issues/14251
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ issue_8631_logger.debug(
+ "Recorded outbound pokes for %s:%s with device stream ids %s",
+ user_id,
+ device_id,
+ {
+ stream_id: destination
+ for (destination, stream_id, _, _, _, _, _) in values
+ },
+ )
+
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
@@ -1864,27 +2015,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def get_uncoverted_outbound_room_pokes(
- self, limit: int = 10
+ self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
+ Args:
+ start_stream_id: Together with `start_room_id`, indicates the position after
+ which to return device list changes.
+ start_room_id: Together with `start_stream_id`, indicates the position after
+ which to return device list changes.
+ limit: The maximum number of device list changes to return.
+
Returns:
- A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ A list of user ID, device ID, room ID, stream ID and optional opentracing
+ context, in order of ascending (stream ID, room ID).
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
- WHERE NOT converted_to_destinations
- ORDER BY stream_id
+ WHERE
+ (stream_id, room_id) > (?, ?) AND
+ stream_id <= ? AND
+ NOT converted_to_destinations
+ ORDER BY stream_id ASC, room_id ASC
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
- txn.execute(sql, (limit,))
+ txn.execute(
+ sql,
+ (
+ start_stream_id,
+ start_room_id,
+ # Avoid returning rows if there may be uncommitted device list
+ # changes with smaller stream IDs.
+ self._device_list_id_gen.get_current_token(),
+ limit,
+ ),
+ )
return [
(
@@ -1906,52 +2078,119 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
- stream_id: int,
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
-
- Marks the associated row in `device_lists_changes_in_room` as handled.
"""
+ if not hosts:
+ return
def add_device_list_outbound_pokes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
- if hosts:
- self._add_device_outbound_poke_to_stream_txn(
- txn,
- user_id=user_id,
- device_ids=[device_id],
- hosts=hosts,
- stream_ids=stream_ids,
- context=context,
- )
-
- self.db_pool.simple_update_txn(
+ self._add_device_outbound_poke_to_stream_txn(
txn,
- table="device_lists_changes_in_room",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "room_id": room_id,
- },
- updatevalues={"converted_to_destinations": True},
+ user_id=user_id,
+ device_id=device_id,
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
)
- if not hosts:
- # If there are no hosts then we don't try and generate stream IDs.
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
- [],
+ stream_ids,
)
- async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
- return await self.db_pool.runInteraction(
- "add_device_list_outbound_pokes",
- add_device_list_outbound_pokes_txn,
- stream_ids,
+ async def add_remote_device_list_to_pending(
+ self, user_id: str, device_id: str
+ ) -> None:
+ """Add a device list update to the table tracking remote device list
+ updates during partial joins.
+ """
+
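+ # Upsert so that only the most recent stream ID is kept for each
+ # (user_id, device_id) pair.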
+ async with self._device_list_id_gen.get_next() as stream_id:
+ await self.db_pool.simple_upsert(
+ table="device_lists_remote_pending",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={"stream_id": stream_id},
+ desc="add_remote_device_list_to_pending",
)
+
+ async def get_pending_remote_device_list_updates_for_room(
+ self, room_id: str
+ ) -> Collection[Tuple[str, str]]:
+ """Get the set of remote device list updates from the pending table for
+ the room.
+ """
+
+ min_device_stream_id = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="device_lists_stream_id",
+ desc="get_pending_remote_device_list_updates_for_room_device",
+ )
+
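+ # Only return pending updates for users who are currently joined to the room.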
+ sql = """
+ SELECT user_id, device_id FROM device_lists_remote_pending AS d
+ INNER JOIN current_state_events AS c ON
+ type = 'm.room.member'
+ AND state_key = user_id
+ AND membership = 'join'
+ WHERE
+ room_id = ? AND stream_id > ?
+ """
+
+ def get_pending_remote_device_list_updates_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Collection[Tuple[str, str]]:
+ txn.execute(sql, (room_id, min_device_stream_id))
+ return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction(
+ "get_pending_remote_device_list_updates_for_room",
+ get_pending_remote_device_list_updates_for_room_txn,
+ )
+
+ async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
+ """
+ Get the position of the last row in `device_lists_changes_in_room` that has been
+ converted to `device_lists_outbound_pokes`.
+
+ Rows with a strictly greater position where `converted_to_destinations` is
+ `FALSE` have not been converted.
+ """
+
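+ # The table holds a single row, hence the empty keyvalues.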
+ row = await self.db_pool.simple_select_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ retcols=["stream_id", "room_id"],
+ desc="get_device_change_last_converted_pos",
+ )
+ return row["stream_id"], row["room_id"]
+
+ async def set_device_change_last_converted_pos(
+ self,
+ stream_id: int,
+ room_id: str,
+ ) -> None:
+ """
+ Set the position of the last row in `device_lists_changes_in_room` that has been
+ converted to `device_lists_outbound_pokes`.
+ """
+
+ await self.db_pool.simple_update_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id, "room_id": room_id},
+ desc="set_device_change_last_converted_pos",
+ )