summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/__init__.py3
-rw-r--r--synapse/storage/databases/main/devices.py52
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py53
3 files changed, 60 insertions, 48 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 70cf15dd7f..e6536c8456 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -264,6 +264,9 @@ class DataStore(
         # Used in _generate_user_daily_visits to keep track of progress
         self._last_user_visit_update = self._get_start_of_day()
 
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
         self._presence_on_startup = None
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index def96637a2..e8379c73c4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
             update included in the response), and the list of updates, where
             each update is a pair of EDU type and EDU contents.
         """
-        now_stream_id = self._device_list_id_gen.get_current_token()
+        now_stream_id = self.get_device_stream_token()
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore):
             },
         )
 
+    @abc.abstractmethod
     def get_device_stream_token(self) -> int:
-        return self._device_list_id_gen.get_current_token()
+        """Get the current stream id from the _device_list_id_gen"""
+        ...
 
     @trace
     async def get_user_devices_from_cache(
@@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
-    def get_devices_with_keys_by_user(self, user_id: str):
-        """Get all devices (with any device keys) for a user
-
-        Returns:
-            Deferred which resolves to (stream_id, devices)
-        """
-        return self.db_pool.runInteraction(
-            "get_devices_with_keys_by_user",
-            self._get_devices_with_keys_by_user_txn,
-            user_id,
-        )
-
-    def _get_devices_with_keys_by_user_txn(
-        self, txn: LoggingTransaction, user_id: str
-    ) -> Tuple[int, List[JsonDict]]:
-        now_stream_id = self._device_list_id_gen.get_current_token()
-
-        devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
-
-        if devices:
-            user_devices = devices[user_id]
-            results = []
-            for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
-
-                key_json = device.get("key_json", None)
-                if key_json:
-                    result["keys"] = db_to_json(key_json)
-
-                    if "signatures" in device:
-                        for sig_user_id, sigs in device["signatures"].items():
-                            result["keys"].setdefault("signatures", {}).setdefault(
-                                sig_user_id, {}
-                            ).update(sigs)
-
-                device_display_name = device.get("device_display_name", None)
-                if device_display_name:
-                    result["device_display_name"] = device_display_name
-
-                results.append(result)
-
-            return now_stream_id, results
-
-        return now_stream_id, []
-
     async def get_users_whose_devices_changed(
         self, from_key: str, user_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 50ecddf7fa..fb3b1f94de 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,6 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
@@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -33,6 +34,51 @@ if TYPE_CHECKING:
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
+    def get_e2e_device_keys_for_federation_query(self, user_id: str):
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            Deferred which resolves to (stream_id, devices)
+        """
+        return self.db_pool.runInteraction(
+            "get_e2e_device_keys_for_federation_query",
+            self._get_e2e_device_keys_for_federation_query_txn,
+            user_id,
+        )
+
+    def _get_e2e_device_keys_for_federation_query_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
+        now_stream_id = self.get_device_stream_token()
+
+        devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.items():
+                result = {"device_id": device_id}
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = db_to_json(key_json)
+
+                    if "signatures" in device:
+                        for sig_user_id, sigs in device["signatures"].items():
+                            result["keys"].setdefault("signatures", {}).setdefault(
+                                sig_user_id, {}
+                            ).update(sigs)
+
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
     @trace
     async def get_e2e_device_keys_for_cs_api(
         self, query_list: List[Tuple[str, Optional[str]]]
@@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             _get_all_user_signature_changes_for_remotes_txn,
         )
 
+    @abc.abstractmethod
+    def get_device_stream_token(self) -> int:
+        """Get the current stream id from the _device_list_id_gen"""
+        ...
+
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):