summary refs log tree commit diff
path: root/synapse/storage/databases/main/appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/appservice.py')
-rw-r--r--synapse/storage/databases/main/appservice.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index abea4383c7..55e1ab099d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import _CacheContext, cached
 
@@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
         to_device_messages: List[JsonDict],
         one_time_key_counts: TransactionOneTimeKeyCounts,
         unused_fallback_keys: TransactionUnusedFallbackKeys,
+        device_list_summary: DeviceListUpdates,
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore(
                 appservice devices in the transaction.
             unused_fallback_keys: Lists of unused fallback keys for relevant
                 appservice devices in the transaction.
+            device_list_summary: The device list summary to include in the transaction.
 
         Returns:
             A new transaction.
@@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore(
                 to_device_messages=to_device_messages,
                 one_time_key_counts=one_time_key_counts,
                 unused_fallback_keys=unused_fallback_keys,
+                device_list_summary=device_list_summary,
             )
 
         return await self.db_pool.runInteraction(
@@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore(
 
         events = await self.get_events_as_list(event_ids)
 
-        # TODO: to-device messages, one-time key counts and unused fallback keys
-        #       are not yet populated for catch-up transactions.
+        # TODO: to-device messages, one-time key counts, device list summaries and unused
+        #       fallback keys are not yet populated for catch-up transactions.
         #       We likely want to populate those for reliability.
         return AppServiceTransaction(
             service=service,
@@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore(
             to_device_messages=[],
             one_time_key_counts={},
             unused_fallback_keys={},
+            device_list_summary=DeviceListUpdates(),
         )
 
     def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence", "to_device"):
+        if type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
@@ -458,7 +462,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence", "to_device"):
+        if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)