diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ec2fd561cc..16c9a162c5 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
-from synapse.storage.end_to_end_keys import EndToEndKeyStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.devices import DeviceWorkerStore
+from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
-
-class SlavedDeviceStore(BaseSlavedStore):
+class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
@@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
"DeviceListFederationStreamChangeCache", device_list_max,
)
- get_device_stream_token = __func__(DataStore.get_device_stream_token)
- get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
- get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
- _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
- _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
- mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
- _mark_as_sent_devices_by_remote_txn = (
- __func__(DataStore._mark_as_sent_devices_by_remote_txn)
- )
- count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
-
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
@@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
- self._device_list_stream_cache.entity_has_changed(
- row.user_id, token
+ self._invalidate_caches_for_devices(
+ token, row.user_id, row.destination,
)
-
- if row.destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- row.destination, token
- )
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
+
+ def _invalidate_caches_for_devices(self, token, user_id, destination):
+ self._device_list_stream_cache.entity_has_changed(
+ user_id, token
+ )
+
+ if destination:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ destination, token
+ )
+
+ self._get_cached_devices_for_user.invalidate((user_id,))
+ self._get_cached_user_device.invalidate_many((user_id,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
|