diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..efbd87918e 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
result["tag_account_data"] = position
return result
- def process_replication(self, result):
- stream = result.get("user_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id, data_type = row[:3]
- self.get_global_account_data_by_type_for_user.invalidate(
- (data_type, user_id,)
- )
- self.get_account_data_for_user.invalidate((user_id,))
+ def process_replication_rows(self, stream_name, token, rows):
+ if stream_name == "tag_account_data":
+ self._account_data_id_gen.advance(token)
+ for row in rows:
+ self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- stream = result.get("room_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
- self.get_account_data_for_user.invalidate((user_id,))
+ elif stream_name == "account_data":
+ self._account_data_id_gen.advance(token)
+ for row in rows:
+ if not row.room_id:
+ self.get_global_account_data_by_type_for_user.invalidate(
+ (row.data_type, row.user_id,)
+ )
+ self.get_account_data_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
- user_id, position
+ row.user_id, token
)
-
- stream = result.get("tag_account_data")
- if stream:
- self._account_data_id_gen.advance(int(stream["position"]))
- for row in stream["rows"]:
- position, user_id = row[:2]
- self.get_tags_for_user.invalidate((user_id,))
- self._account_data_stream_cache.entity_has_changed(
- user_id, position
- )
-
- return super(SlavedAccountDataStore, self).process_replication(result)
+ return super(SlavedAccountDataStore, self).process_replication_rows(
+ stream_name, token, rows
+ )
|