diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 18358eca46..a5bb4d404e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
import logging
from typing import (
TYPE_CHECKING,
@@ -39,6 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -49,11 +49,19 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.stream_change_cache import (
+ AllEntitiesChangedResult,
+ StreamChangeCache,
+)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -80,9 +88,23 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
):
super().__init__(database, db_conn, hs)
+ # In the worker store this is an ID tracker which we overwrite in the non-worker
+ # class below that is used on the main process.
+ self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
+ ],
+ is_writer=hs.config.worker.worker_app is None,
+ )
+
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
- device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
+ device_list_max = self._device_list_id_gen.get_current_token()
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@@ -136,6 +158,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
+ if stream_name == DeviceListsStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+ self._invalidate_caches_for_devices(token, rows)
+ elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def _invalidate_caches_for_devices(
+ self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+ ) -> None:
+ for row in rows:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
+
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
+
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -274,6 +329,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
destination, int(from_stream_id)
)
if not has_changed:
+ # debugging for https://github.com/matrix-org/synapse/issues/14251
+ issue_8631_logger.debug(
+ "%s: no change between %i and %i",
+ destination,
+ from_stream_id,
+ now_stream_id,
+ )
return now_stream_id, []
updates = await self.db_pool.runInteraction(
@@ -466,7 +528,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
limit: Maximum number of device updates to return
Returns:
- List: List of device update tuples:
+ List of device update tuples:
- user_id
- device_id
- stream_id
@@ -539,9 +601,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
- "org.matrix.opentracing_context": opentracing_context,
}
+ if opentracing_context != "{}":
+ result["org.matrix.opentracing_context"] = opentracing_context
+
prev_id = stream_id
if device is not None:
@@ -549,7 +613,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if keys:
result["keys"] = keys
- device_display_name = device.display_name
+ device_display_name = None
+ if (
+ self.hs.config.federation.allow_device_name_lookup_over_federation
+ ):
+ device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@@ -664,11 +732,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
},
)
- @abc.abstractmethod
- def get_device_stream_token(self) -> int:
- """Get the current stream id from the _device_list_id_gen"""
- ...
-
@trace
@cancellable
async def get_user_devices_from_cache(
@@ -739,7 +802,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_cached_device_list_changes(
self,
from_key: int,
- ) -> Optional[List[str]]:
+ ) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@@ -747,10 +810,58 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
+ async def get_all_devices_changed(
+ self,
+ from_key: int,
+ to_key: int,
+ ) -> Set[str]:
+ """Get all users whose devices have changed in the given range.
+
+ Args:
+ from_key: The minimum device lists stream token to query device list
+ changes for, exclusive.
+ to_key: The maximum device lists stream token to query device list
+ changes for, inclusive.
+
+ Returns:
+ The set of user_ids whose devices have changed since `from_key`
+ (exclusive) until `to_key` (inclusive).
+ """
+
+ result = self._device_list_stream_cache.get_all_entities_changed(from_key)
+
+ if result.hit:
+ # We know which users might have changed devices.
+ if not result.entities:
+ # If no users then we can return early.
+ return set()
+
+ # Otherwise we need to filter down the list
+ return await self.get_users_whose_devices_changed(
+ from_key, result.entities, to_key
+ )
+
+ # If the cache didn't tell us anything, we just need to query the full
+ # range.
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+
+ rows = await self.db_pool.execute(
+ "get_all_devices_changed",
+ None,
+ sql,
+ from_key,
+ to_key,
+ )
+ return {u for u, in rows}
+
+ @cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
- user_ids: Optional[Collection[str]] = None,
+ user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@@ -770,46 +881,31 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- user_ids_to_check: Optional[Collection[str]]
- if user_ids is None:
- # Get set of all users that have had device list changes since 'from_key'
- user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
- from_key
- )
- else:
- # The same as above, but filter results to only those users in 'user_ids'
- user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
+ # If an empty set was returned, there's nothing to do.
if not user_ids_to_check:
return set()
- def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
- changes: Set[str] = set()
-
- stream_id_where_clause = "stream_id > ?"
- sql_args = [from_key]
-
- if to_key:
- stream_id_where_clause += " AND stream_id <= ?"
- sql_args.append(to_key)
+ if to_key is None:
+ to_key = self._device_list_id_gen.get_current_token()
- sql = f"""
+ def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
+ sql = """
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE {stream_id_where_clause}
- AND
+ WHERE ? < stream_id AND stream_id <= ? AND %s
"""
+ changes: Set[str] = set()
+
# Query device changes with a batch of users at a time
- # Assertion for mypy's benefit; see also
- # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
- assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, sql_args + args)
+ txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)
return changes
@@ -1381,6 +1477,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._remove_duplicate_outbound_pokes,
)
+ self.db_pool.updates.register_background_index_update(
+ "device_lists_changes_in_room_by_room_index",
+ index_name="device_lists_changes_in_room_by_room_idx",
+ table="device_lists_changes_in_room",
+ columns=["room_id", "stream_id"],
+ )
+
async def _drop_device_list_streams_non_unique_indexes(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1468,6 +1571,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ # Because we have write access, this will be a StreamIdGenerator
+ # (see DeviceWorkerStore.__init__)
+ _device_list_id_gen: AbstractStreamIdGenerator
+
def __init__(
self,
database: DatabasePool,
@@ -1673,9 +1780,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
values={"content": json_encoder.encode(content)},
- # we don't need to lock, because we assume we are the only thread
- # updating this user's devices.
- lock=False,
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
@@ -1689,9 +1793,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
- # again, we can assume we are the only thread updating this user's
- # extremity.
- lock=False,
)
async def update_remote_device_list_cache(
@@ -1744,9 +1845,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
- # we don't need to lock, because we can assume we are the only thread
- # updating this user's extremity.
- lock=False,
)
async def add_device_change_to_streams(
@@ -1792,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context,
)
- async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1842,7 +1940,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Iterable[str],
+ device_id: str,
hosts: Collection[str],
stream_ids: List[int],
context: Optional[Dict[str, str]],
@@ -1858,6 +1956,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
+ mark_sent = not self.hs.is_mine_id(user_id)
+
+ values = [
+ (
+ destination,
+ next(stream_id_iterator),
+ user_id,
+ device_id,
+ mark_sent,
+ now,
+ encoded_context if whitelisted_homeserver(destination) else "{}",
+ )
+ for destination in hosts
+ ]
+
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1870,23 +1983,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"ts",
"opentracing_context",
),
- values=[
- (
- destination,
- next(stream_id_iterator),
- user_id,
- device_id,
- not self.hs.is_mine_id(
- user_id
- ), # We only need to send out update for *our* users
- now,
- encoded_context if whitelisted_homeserver(destination) else "{}",
- )
- for destination in hosts
- for device_id in device_ids
- ],
+ values=values,
)
+ # debugging for https://github.com/matrix-org/synapse/issues/14251
+ if issue_8631_logger.isEnabledFor(logging.DEBUG):
+ issue_8631_logger.debug(
+ "Recorded outbound pokes for %s:%s with device stream ids %s",
+ user_id,
+ device_id,
+ {
+ stream_id: destination
+ for (destination, stream_id, _, _, _, _, _) in values
+ },
+ )
+
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
@@ -1931,27 +2042,48 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def get_uncoverted_outbound_room_pokes(
- self, limit: int = 10
+ self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
+ Args:
+ start_stream_id: Together with `start_room_id`, indicates the position after
+ which to return device list changes.
+ start_room_id: Together with `start_stream_id`, indicates the position after
+ which to return device list changes.
+ limit: The maximum number of device list changes to return.
+
Returns:
- A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ A list of user ID, device ID, room ID, stream ID and optional opentracing
+ context, in order of ascending (stream ID, room ID).
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
- WHERE NOT converted_to_destinations
- ORDER BY stream_id
+ WHERE
+ (stream_id, room_id) > (?, ?) AND
+ stream_id <= ? AND
+ NOT converted_to_destinations
+ ORDER BY stream_id ASC, room_id ASC
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
- txn.execute(sql, (limit,))
+ txn.execute(
+ sql,
+ (
+ start_stream_id,
+ start_room_id,
+ # Avoid returning rows if there may be uncommitted device list
+ # changes with smaller stream IDs.
+ self._device_list_id_gen.get_current_token(),
+ limit,
+ ),
+ )
return [
(
@@ -1973,52 +2105,28 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
- stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
-
- Marks the associated row in `device_lists_changes_in_room` as handled,
- if `stream_id` is provided.
"""
+ if not hosts:
+ return
def add_device_list_outbound_pokes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
- if hosts:
- self._add_device_outbound_poke_to_stream_txn(
- txn,
- user_id=user_id,
- device_ids=[device_id],
- hosts=hosts,
- stream_ids=stream_ids,
- context=context,
- )
-
- if stream_id:
- self.db_pool.simple_update_txn(
- txn,
- table="device_lists_changes_in_room",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "room_id": room_id,
- },
- updatevalues={"converted_to_destinations": True},
- )
-
- if not hosts:
- # If there are no hosts then we don't try and generate stream IDs.
- return await self.db_pool.runInteraction(
- "add_device_list_outbound_pokes",
- add_device_list_outbound_pokes_txn,
- [],
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_id=device_id,
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
)
- async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
@@ -2032,7 +2140,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates during partial joins.
"""
- async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.simple_upsert(
table="device_lists_remote_pending",
keyvalues={
@@ -2079,3 +2187,37 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"get_pending_remote_device_list_updates_for_room",
get_pending_remote_device_list_updates_for_room_txn,
)
+
+ async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
+ """
+ Get the position of the last row in `device_list_changes_in_room` that has been
+ converted to `device_lists_outbound_pokes`.
+
+ Rows with a strictly greater position where `converted_to_destinations` is
+ `FALSE` have not been converted.
+ """
+
+ row = await self.db_pool.simple_select_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ retcols=["stream_id", "room_id"],
+ desc="get_device_change_last_converted_pos",
+ )
+ return row["stream_id"], row["room_id"]
+
+ async def set_device_change_last_converted_pos(
+ self,
+ stream_id: int,
+ room_id: str,
+ ) -> None:
+ """
+ Set the position of the last row in `device_list_changes_in_room` that has been
+ converted to `device_lists_outbound_pokes`.
+ """
+
+ await self.db_pool.simple_update_one(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id, "room_id": room_id},
+ desc="set_device_change_last_converted_pos",
+ )
|