diff options
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r-- | synapse/storage/databases/main/devices.py | 63 |
1 files changed, 46 insertions, 17 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index bc7e876047..b2a5cd9a65 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -53,6 +53,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) +issue_8631_logger = logging.getLogger("synapse.8631_debug") DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( "drop_device_list_streams_non_unique_indexes" @@ -229,6 +230,12 @@ class DeviceWorkerStore(SQLBaseStore): if not updates: return now_stream_id, [] + if issue_8631_logger.isEnabledFor(logging.DEBUG): + data = {(user, device): stream_id for user, device, stream_id, _ in updates} + issue_8631_logger.debug( + "device updates need to be sent to %s: %s", destination, data + ) + # get the cross-signing keys of the users in the list, so that we can # determine which of the device changes were cross-signing keys users = {r[0] for r in updates} @@ -365,6 +372,17 @@ class DeviceWorkerStore(SQLBaseStore): # and remove the length budgeting above. results.append(("org.matrix.signing_key_update", result)) + if issue_8631_logger.isEnabledFor(logging.DEBUG): + for (user_id, edu) in results: + issue_8631_logger.debug( + "device update to %s for %s from %s to %s: %s", + destination, + user_id, + from_stream_id, + last_processed_stream_id, + edu, + ) + return last_processed_stream_id, results def _get_device_updates_by_remote_txn( @@ -781,7 +799,7 @@ class DeviceWorkerStore(SQLBaseStore): @cached(max_entries=10000) async def get_device_list_last_stream_id_for_remote( self, user_id: str - ) -> Optional[Any]: + ) -> Optional[str]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ @@ -797,7 +815,9 @@ class DeviceWorkerStore(SQLBaseStore): cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", ) - async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): + async def get_device_list_last_stream_id_for_remotes( + self, user_ids: Iterable[str] + ) -> Dict[str, Optional[str]]: rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -1384,6 +1404,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): content: JsonDict, stream_id: str, ) -> None: + """Delete, update or insert a cache entry for this (user, device) pair.""" if content.get("deleted"): self.db_pool.simple_delete_txn( txn, @@ -1443,6 +1464,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int ) -> None: + """Replace the list of cached devices for this user with the given list.""" self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) @@ -1450,12 +1472,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_remote_cache", + keys=("user_id", "device_id", "content"), values=[ - { - "user_id": user_id, - "device_id": content["device_id"], - "content": json_encoder.encode(content), - } + (user_id, content["device_id"], json_encoder.encode(content)) for content in devices ], ) @@ -1543,8 +1562,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_stream", + keys=("stream_id", "user_id", "device_id"), values=[ - {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} + (stream_id, user_id, device_id) for stream_id, device_id in zip(stream_ids, device_ids) ], ) @@ -1571,18 +1591,27 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", + keys=( + "destination", + "stream_id", + "user_id", + "device_id", + "sent", + "ts", + "opentracing_context", + ), values=[ - { - "destination": destination, - "stream_id": next(next_stream_id), - "user_id": user_id, - "device_id": device_id, - "sent": False, - "ts": now, - "opentracing_context": json_encoder.encode(context) + ( + destination, + next(next_stream_id), + user_id, + device_id, + False, + now, + json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", - } + ) for destination in hosts for device_id in device_ids ], |