diff --git a/changelog.d/17692.bugfix b/changelog.d/17692.bugfix
new file mode 100644
index 0000000000..84e0754a99
--- /dev/null
+++ b/changelog.d/17692.bugfix
@@ -0,0 +1 @@
+Make sure we get up-to-date state information when using the new Sliding Sync tables to derive room membership.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d6deb077c8..e14d711c76 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -136,6 +136,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
self._attempt_to_invalidate_cache("get_room_type", (room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
+ self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
def _invalidate_state_caches_all(self, room_id: str) -> None:
"""Invalidates caches that are based on the current state, but does
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index b0e30daee5..37c865a8e7 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -41,6 +41,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
+from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import CachedFunction
@@ -271,12 +272,20 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_rooms_for_user", (data.state_key,)
)
+ self._attempt_to_invalidate_cache(
+ "get_sliding_sync_rooms_for_user", None
+ )
elif data.type == EventTypes.RoomEncryption:
self._attempt_to_invalidate_cache(
"get_room_encryption", (data.room_id,)
)
elif data.type == EventTypes.Create:
self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
+
+ if (data.type, data.state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
+ self._attempt_to_invalidate_cache(
+ "get_sliding_sync_rooms_for_user", None
+ )
elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the entire caches are invalidated. This is
@@ -285,6 +294,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_rooms_for_user", None)
self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
+ self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
@@ -365,6 +375,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
elif etype == EventTypes.RoomEncryption:
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
+ if (etype, state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
+ self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
+
if relates_to:
self._attempt_to_invalidate_cache(
"get_relations_for_event",
@@ -477,6 +490,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_current_hosts_in_room_ordered", (room_id,)
)
+ self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index db03729cfe..1fc2d7ba1e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1365,6 +1365,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
+ self._invalidate_cache_and_stream(
+ txn, self.get_sliding_sync_rooms_for_user, (user_id,)
+ )
await self.db_pool.runInteraction("forget_membership", f)
@@ -1410,6 +1413,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
def get_sliding_sync_rooms_for_user_txn(
txn: LoggingTransaction,
) -> Dict[str, RoomsForUserSlidingSync]:
+ # XXX: If you use any new columns that can change (like from
+ # `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the
+ # `get_sliding_sync_rooms_for_user` cache in the appropriate places (and add
+ # tests).
sql = """
SELECT m.room_id, m.sender, m.membership, m.membership_event_id,
r.room_version,
@@ -1432,7 +1439,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
room_version_id=row[4],
event_pos=PersistedEventPosition(row[5], row[6]),
room_type=row[7],
- is_encrypted=row[8],
+ is_encrypted=bool(row[8]),
)
for row in txn
}
diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py
index 930cb5ef45..9e23dbe522 100644
--- a/tests/rest/client/sliding_sync/test_sliding_sync.py
+++ b/tests/rest/client/sliding_sync/test_sliding_sync.py
@@ -722,43 +722,37 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.helper.join(space_room_id, user1_id, tok=user1_tok)
# Make an initial Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint,
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
+ sync_body = {
+ "lists": {
+ "all-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {
+ "is_encrypted": True,
+ "room_types": [RoomTypes.SPACE],
},
- }
- },
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- from_token = channel.json_body["pos"]
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
# Make sure the response has the lists we requested
self.assertListEqual(
- list(channel.json_body["lists"].keys()),
+ list(response_body["lists"].keys()),
["all-list", "foo-list"],
- channel.json_body["lists"].keys(),
+ response_body["lists"].keys(),
)
# Make sure the lists have the correct rooms
self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
+ list(response_body["lists"]["all-list"]["ops"]),
[
{
"op": "SYNC",
@@ -768,7 +762,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
],
)
self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
+ list(response_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
@@ -783,35 +777,30 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.helper.leave(space_room_id, user2_id, tok=user2_tok)
# Make an incremental Sliding Sync request
- channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?pos={from_token}",
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
+ sync_body = {
+ "lists": {
+ "all-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {
+ "is_encrypted": True,
+ "room_types": [RoomTypes.SPACE],
},
- }
- },
- access_token=user1_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
# Make sure the lists have the correct rooms even though we `newly_left`
self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
+ list(response_body["lists"]["all-list"]["ops"]),
[
{
"op": "SYNC",
@@ -821,7 +810,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
],
)
self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
+ list(response_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
@@ -831,6 +820,98 @@ class SlidingSyncTestCase(SlidingSyncBase):
],
)
+ def test_filter_is_encrypted_up_to_date(self) -> None:
+ """
+ Make sure we get up-to-date `is_encrypted` status for a joined room
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Update the encryption status
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # We should see the room now because it's encrypted
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_forgotten_up_to_date(self) -> None:
+ """
+ Make sure we get up-to-date `forgotten` status for rooms
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 is banned from the room (was never in the room)
+ self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # User1 forgets the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # We should no longer see the forgotten room
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
def test_sort_list(self) -> None:
"""
Test that the `lists` are sorted by `stream_ordering`
|