diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f19540d6bb..18076e0f3b 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -34,6 +34,9 @@ class BaseSlavedStore(SQLBaseStore):
else:
self._cache_id_gen = None
+ self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
+ self.http_client = hs.get_simple_http_client()
+
def stream_positions(self):
pos = {}
if self._cache_id_gen:
@@ -54,3 +57,19 @@ class BaseSlavedStore(SQLBaseStore):
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
+
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ txn.call_after(cache_func.invalidate, keys)
+ txn.call_after(self._send_invalidation_poke, cache_func, keys)
+
+ @defer.inlineCallbacks
+ def _send_invalidation_poke(self, cache_func, keys):
+ try:
+ yield self.http_client.post_json_get_json(self.expire_cache_url, {
+ "invalidate": [{
+ "name": cache_func.__name__,
+ "keys": list(keys),
+ }]
+ })
+ except:
+ logger.exception("Failed to poke on expire_cache")
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 3bfd5e8213..cc860f9f9b 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -29,10 +29,16 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
"DeviceInboxStreamChangeCache",
self._device_inbox_id_gen.get_current_token()
)
+ self._device_federation_outbox_stream_cache = StreamChangeCache(
+ "DeviceFederationOutboxStreamChangeCache",
+ self._device_inbox_id_gen.get_current_token()
+ )
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
+ get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
+ delete_device_msgs_for_remote = DataStore.delete_device_msgs_for_remote.__func__
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
@@ -45,9 +51,15 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
- user_id = row[1]
- self._device_inbox_stream_cache.entity_has_changed(
- user_id, stream_id
- )
+ entity = row[1]
+
+ if entity.startswith("@"):
+ self._device_inbox_stream_cache.entity_has_changed(
+ entity, stream_id
+ )
+ else:
+ self._device_federation_outbox_stream_cache.entity_has_changed(
+ entity, stream_id
+ )
return super(SlavedDeviceInboxStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 0c26e96e98..64f18bbb3e 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -26,6 +26,11 @@ from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@@ -180,6 +185,11 @@ class SlavedEventStore(BaseSlavedStore):
EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
)
+ get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
+
+ get_federation_out_pos = DataStore.get_federation_out_pos.__func__
+ update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
@@ -194,6 +204,10 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
+
+ if stream["rows"]:
+ logger.info("Got %d event rows", len(stream["rows"]))
+
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 6f2ba98af5..fbb58f35da 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
@@ -22,9 +21,10 @@ from synapse.storage.transactions import TransactionStore
class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[
"get_destination_retry_timings"
- ].orig
+ ]
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
+ set_destination_retry_timings = DataStore.set_destination_retry_timings.__func__
+ _set_destination_retry_timings = DataStore._set_destination_retry_timings.__func__
- # For now, don't record the destination rety timings
- def set_destination_retry_timings(*args, **kwargs):
- return defer.succeed(None)
+ prep_send_transaction = DataStore.prep_send_transaction.__func__
+ delivered_txn = DataStore.delivered_txn.__func__
|