diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 4f19fd35aa..4d59778863 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
-
-class SlavedDeviceInboxStore(BaseSlavedStore):
+class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
@@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
- get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
- get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
- get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
- delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
- delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
-
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ec2fd561cc..16c9a162c5 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
-from synapse.storage.end_to_end_keys import EndToEndKeyStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.devices import DeviceWorkerStore
+from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
-
-class SlavedDeviceStore(BaseSlavedStore):
+class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
@@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
"DeviceListFederationStreamChangeCache", device_list_max,
)
- get_device_stream_token = __func__(DataStore.get_device_stream_token)
- get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
- get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
- _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
- _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
- mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
- _mark_as_sent_devices_by_remote_txn = (
- __func__(DataStore._mark_as_sent_devices_by_remote_txn)
- )
- count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
-
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
@@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
- self._device_list_stream_cache.entity_has_changed(
- row.user_id, token
+ self._invalidate_caches_for_devices(
+ token, row.user_id, row.destination,
)
-
- if row.destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- row.destination, token
- )
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
+
+ def _invalidate_caches_for_devices(self, token, user_id, destination):
+ self._device_list_stream_cache.entity_has_changed(
+ user_id, token
+ )
+
+ if destination:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ destination, token
+ )
+
+ self._get_cached_devices_for_user.invalidate((user_id,))
+ self._get_cached_user_device.invalidate_many((user_id,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index f0200c1e98..45fc913c52 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
-class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
+class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
|