diff options
Diffstat (limited to 'synapse/storage/end_to_end_keys.py')
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2cebb203c6..523b4360c3 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from six import iteritems + +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer from synapse.util.caches.descriptors import cached -from canonicaljson import encode_canonical_json -import ujson as json - from ._base import SQLBaseStore @@ -63,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore): ) @defer.inlineCallbacks - def get_e2e_device_keys(self, query_list, include_all_devices=False): + def get_e2e_device_keys( + self, query_list, include_all_devices=False, + include_deleted_devices=False, + ): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. include_all_devices (bool): whether to include entries for devices that don't have device keys + include_deleted_devices (bool): whether to include null entries for + devices which no longer exist (but were in the query_list). + This option only takes effect if include_all_devices is true. Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". @@ -78,19 +85,28 @@ class EndToEndKeyStore(SQLBaseStore): results = yield self.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, - query_list, include_all_devices, + query_list, include_all_devices, include_deleted_devices, ) - for user_id, device_keys in results.iteritems(): - for device_id, device_info in device_keys.iteritems(): + for user_id, device_keys in iteritems(results): + for device_id, device_info in iteritems(device_keys): device_info["keys"] = json.loads(device_info.pop("key_json")) defer.returnValue(results) - def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): + def _get_e2e_device_keys_txn( + self, txn, query_list, include_all_devices=False, + include_deleted_devices=False, + ): query_clauses = [] query_params = [] + if include_all_devices is False: + include_deleted_devices = False + + if include_deleted_devices: + deleted_devices = set(query_list) + for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) @@ -118,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore): result = {} for row in rows: + if include_deleted_devices: + deleted_devices.remove((row["user_id"], row["device_id"])) result.setdefault(row["user_id"], {})[row["device_id"]] = row + if include_deleted_devices: + for user_id, device_id in deleted_devices: + result.setdefault(user_id, {})[device_id] = None + return result @defer.inlineCallbacks |