diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/replication/slave/storage/deviceinbox.py | 15 | ||||
-rw-r--r-- | synapse/replication/slave/storage/devices.py | 45 | ||||
-rw-r--r-- | synapse/replication/slave/storage/push_rule.py | 2 |
3 files changed, 26 insertions, 36 deletions
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4f19fd35aa..4d59778863 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedDeviceInboxStore(BaseSlavedStore): +class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( @@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token) - get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device) - get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote) - delete_messages_for_device = __func__(DataStore.delete_messages_for_device) - delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote) - def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() result["to_device"] = self._device_inbox_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ec2fd561cc..16c9a162c5 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore -from synapse.storage.end_to_end_keys import EndToEndKeyStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.devices import DeviceWorkerStore +from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedDeviceStore(BaseSlavedStore): +class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceStore, self).__init__(db_conn, hs) @@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore): "DeviceListFederationStreamChangeCache", device_list_max, ) - get_device_stream_token = __func__(DataStore.get_device_stream_token) - get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed) - get_devices_by_remote = __func__(DataStore.get_devices_by_remote) - _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn) - _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn) - mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote) - _mark_as_sent_devices_by_remote_txn = ( - __func__(DataStore._mark_as_sent_devices_by_remote_txn) - ) - count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] - def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() result["device_lists"] = self._device_list_id_gen.get_current_token() @@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore): if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._device_list_stream_cache.entity_has_changed( - row.user_id, token + self._invalidate_caches_for_devices( + token, row.user_id, row.destination, ) - - if row.destination: - self._device_list_federation_stream_cache.entity_has_changed( - row.destination, token - ) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) + + def _invalidate_caches_for_devices(self, token, user_id, destination): + self._device_list_stream_cache.entity_has_changed( + user_id, token + ) + + if destination: + self._device_list_federation_stream_cache.entity_has_changed( + destination, token + ) + + self._get_cached_devices_for_user.invalidate((user_id,)) + self._get_cached_user_device.invalidate_many((user_id,)) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index f0200c1e98..45fc913c52 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore -class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): +class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", |