diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index eaead50800..f4bf159bb5 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -143,6 +143,10 @@ class DeviceHandler(BaseHandler):
delete_refresh_tokens=True,
)
+ yield self.store.delete_e2e_keys_by_device(
+ user_id=user_id, device_id=device_id
+ )
+
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 0bf32a089b..4629f4bfde 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -86,10 +86,6 @@ class KeyUploadServlet(RestServlet):
raise synapse.api.errors.SynapseError(
400, "Can only upload keys for current device"
)
-
- self.device_handler.check_device_registered(
- user_id, device_id, "unknown device"
- )
else:
device_id = requester.device_id
@@ -131,6 +127,15 @@ class KeyUploadServlet(RestServlet):
user_id, device_id, time_now, key_list
)
+ # the device should have been registered already, but it may have been
+ # deleted due to a race with a DELETE request. Or we may be using an
+ # old access_token without an associated device_id. Either way, we
+ # need to double-check the device is registered to avoid ending up with
+ # keys without a corresponding device.
+ self.device_handler.check_device_registered(
+ user_id, device_id, "unknown device"
+ )
+
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2e89066515..62b7790e91 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import twisted.internet.defer
+
from ._base import SQLBaseStore
@@ -123,3 +125,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
+
+ @twisted.internet.defer.inlineCallbacks
+ def delete_e2e_keys_by_device(self, user_id, device_id):
+ yield self._simple_delete(
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_device_keys_by_device"
+ )
+ yield self._simple_delete(
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="delete_e2e_one_time_keys_by_device"
+ )
|