summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_events_worker.py8
-rw-r--r--tests/storage/databases/main/test_room.py69
-rw-r--r--tests/storage/test_event_push_actions.py174
-rw-r--r--tests/storage/test_events.py7
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_receipts.py272
-rw-r--r--tests/storage/test_room.py69
-rw-r--r--tests/storage/test_roommember.py123
-rw-r--r--tests/storage/test_state.py10
9 files changed, 529 insertions, 205 deletions
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py

index 38963ce4a7..46d829b062 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py
@@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) def test_simple(self): """Test that we cache events that we pull from the DB.""" @@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): """ # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + +from synapse.api.constants import RoomTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.storage.databases.main.room import _BackgroundUpdates @@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(room_creator_after, self.user_id) + + def test_background_add_room_type_column(self): + """Test that the background update to populate the `room_type` column in + `room_stats_state` works properly. + """ + + # Create a room without a type + room_id = self._generate_room() + + # Get event_id of the m.room.create event + event_id = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="current_state_events", + keyvalues={ + "room_id": room_id, + "type": "m.room.create", + }, + retcol="event_id", + ) + ) + + # Fake a room creation event with a room type + event = { + "content": { + "creator": "@user:server.org", + "room_version": "9", + "type": RoomTypes.SPACE, + }, + "type": "m.room.create", + } + self.get_success( + self.store.db_pool.simple_update( + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": json.dumps(event)}, + desc="test", + ) + ) + + # Insert and run the background update + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Make sure the background update filled in the room type + room_type_after = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcol="room_type", + allow_none=True, + ) + ) + self.assertEqual(room_type_after, RoomTypes.SPACE) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0f9add4841..fc43d7edd1 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -12,143 +12,175 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.util import Clock from tests.unittest import HomeserverTestCase USER_ID = "@user:example.com" -PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] -HIGHLIGHT = [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, -] - class EventPushActionsStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.persist_events_store = hs.get_datastores().persist_events + persist_events_store = hs.get_datastores().persist_events + assert persist_events_store is not None + self.persist_events_store = persist_events_store - def test_get_unread_push_actions_for_user_in_range_for_http(self): + def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None: self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_http( USER_ID, 0, 1000, 20 ) ) - def test_get_unread_push_actions_for_user_in_range_for_email(self): + def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None: self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_email( USER_ID, 0, 1000, 20 ) ) - def test_count_aggregation(self): - room_id = "!foo:example.com" - user_id = "@user1235:example.com" + def test_count_aggregation(self) -> None: + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) - def _assert_counts(noitf_count, highlight_count): + last_event_id: str + + def _assert_counts(noitf_count: int, highlight_count: int) -> None: counts = self.get_success( self.store.db_pool.runInteraction( - "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, ) ) self.assertEqual( counts, NotifCounts( notify_count=noitf_count, - unread_count=0, # Unread counts are tested in the sync tests. + unread_count=0, highlight_count=highlight_count, ), ) - def _inject_actions(stream, action): - event = Mock() - event.room_id = room_id - event.event_id = "$test:example.com" - event.internal_metadata.stream_ordering = stream - event.internal_metadata.is_outlier.return_value = False - event.depth = stream - - self.get_success( - self.store.add_push_actions_to_staging( - event.event_id, - {user_id: action}, - False, - ) - ) - self.get_success( - self.store.db_pool.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], - ) + def _create_event(highlight: bool = False) -> str: + result = self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": user_id if highlight else "msg"}, + tok=other_token, ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id - def _rotate(stream): - self.get_success( - self.store.db_pool.runInteraction( - "", self.store._rotate_notifs_before_txn, stream - ) - ) + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) - def _mark_read(stream, depth): + def _mark_read(event_id: str) -> None: self.get_success( - self.store.db_pool.runInteraction( - "", - self.store._remove_old_push_actions_before_txn, + self.store.insert_receipt( room_id, - user_id, - stream, + "m.read", + user_id=user_id, + event_ids=[event_id], + data={}, ) ) _assert_counts(0, 0) - _inject_actions(1, PlAIN_NOTIF) + _create_event() _assert_counts(1, 0) - _rotate(2) + _rotate() _assert_counts(1, 0) - _inject_actions(3, PlAIN_NOTIF) + event_id = _create_event() _assert_counts(2, 0) - _rotate(4) + _rotate() _assert_counts(2, 0) - _inject_actions(5, PlAIN_NOTIF) - _mark_read(3, 3) + _create_event() + _mark_read(event_id) _assert_counts(1, 0) - _mark_read(5, 5) + _mark_read(last_event_id) _assert_counts(0, 0) - _inject_actions(6, PlAIN_NOTIF) - _rotate(7) + _create_event() + _rotate() + _assert_counts(1, 0) - self.get_success( - self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" + # Delete old event push actions, this should not affect the (summarised) count. + # + # All event push actions are kept for 24 hours, so need to move forward + # in time. + self.pump(60 * 60 * 24) + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + # Double check that the event push actions have been cleared (i.e. that + # any results *must* come from the summary). + result = self.get_success( + self.store.db_pool.simple_select_list( + table="event_push_actions", + keyvalues={"1": 1}, + retcols=("event_id",), + desc="", ) ) - + self.assertEqual(result, []) _assert_counts(1, 0) - _mark_read(7, 7) + _mark_read(last_event_id) _assert_counts(0, 0) - _inject_actions(8, HIGHLIGHT) + event_id = _create_event(True) _assert_counts(1, 1) - _rotate(9) + _rotate() _assert_counts(1, 1) - _rotate(10) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0) + + _create_event(True) _assert_counts(1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0) + _rotate() + _assert_counts(0, 0) - def test_find_first_stream_ordering_after_ts(self): - def add_event(so, ts): + def test_find_first_stream_ordering_after_ts(self) -> None: + def add_event(so: int, ts: int) -> None: self.get_success( self.store.db_pool.simple_insert( "events", diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 2ff88e64a5..3ce4f35cb7 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py
@@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase): def persist_event(self, event, state=None): """Persist the event, with optional state""" context = self.get_success( - self.state.compute_event_context(event, state_ids_before_event=state) + self.state.compute_event_context( + event, + state_ids_before_event=state, + partial_state=None if state is None else False, + ) ) self.get_success(self._persistence.persist_event(event, context)) @@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase): self.state.compute_event_context( remote_event_2, state_ids_before_event=state_before_gap, + partial_state=False, ) ) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 8dfaa0559b..9c1182ed16 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py
@@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase): ) # The events aren't found. - self.store._invalidate_get_event_cache(create_event.event_id) + self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py new file mode 100644
index 0000000000..c89bfff241 --- /dev/null +++ b/tests/storage/test_receipts.py
@@ -0,0 +1,272 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.api.constants import ReceiptTypes +from synapse.types import UserID, create_requester + +from tests.test_utils.event_injection import create_event +from tests.unittest import HomeserverTestCase + +OTHER_USER_ID = "@other:test" +OUR_USER_ID = "@our:test" + + +class ReceiptTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, homeserver) -> None: + super().prepare(reactor, clock, homeserver) + + self.store = homeserver.get_datastores().main + + self.room_creator = homeserver.get_room_creation_handler() + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) + + # Create a test user + self.ourUser = UserID.from_string(OUR_USER_ID) + self.ourRequester = create_requester(self.ourUser) + + # Create a second test user + self.otherUser = UserID.from_string(OTHER_USER_ID) + self.otherRequester = create_requester(self.otherUser) + + # Create a test room + info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) + self.room_id1 = info["room_id"] + + # Create a second test room + info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) + self.room_id2 = info["room_id"] + + # Join the second user to the first room + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id1, + type="m.room.member", + sender=self.otherRequester.user.to_string(), + state_key=self.otherRequester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) + ) + + # Join the second user to the second room + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id2, + type="m.room.member", + sender=self.otherRequester.user.to_string(), + state_key=self.otherRequester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) + ) + + def test_return_empty_with_no_data(self) -> None: + res = self.get_success( + self.store.get_receipts_for_user( + OUR_USER_ID, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ], + ) + ) + self.assertEqual(res, {}) + + res = self.get_success( + self.store.get_receipts_for_user_with_orderings( + OUR_USER_ID, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ], + ) + ) + self.assertEqual(res, {}) + + res = self.get_success( + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id1, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ], + ) + ) + self.assertEqual(res, None) + + def test_get_receipts_for_user(self) -> None: + # Send some events into the first room + event1_1_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + event1_2_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + + # Send public read receipt for the first event + self.get_success( + self.store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + ) + ) + # Send private read receipt for the second event + self.get_success( + self.store.insert_receipt( + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + ) + ) + + # Test we get the latest event when we want both private and public receipts + res = self.get_success( + self.store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Test we get the older event when we want only public receipt + res = self.get_success( + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + ) + self.assertEqual(res, {self.room_id1: event1_1_id}) + + # Test we get the latest event when we want only the public receipt + res = self.get_success( + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Test receipt updating + self.get_success( + self.store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + ) + ) + res = self.get_success( + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Send some events into the second room + event2_1_id = self.create_and_send_event( + self.room_id2, UserID.from_string(OTHER_USER_ID) + ) + + # Test new room is reflected in what the method returns + self.get_success( + self.store.insert_receipt( + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + ) + ) + res = self.get_success( + self.store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) + + def test_get_last_receipt_event_id_for_user(self) -> None: + # Send some events into the first room + event1_1_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + event1_2_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + + # Send public read receipt for the first event + self.get_success( + self.store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + ) + ) + # Send private read receipt for the second event + self.get_success( + self.store.insert_receipt( + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + ) + ) + + # Test we get the latest event when we want both private and public receipts + res = self.get_success( + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id1, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, event1_2_id) + + # Test we get the older event when we want only public receipt + res = self.get_success( + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] + ) + ) + self.assertEqual(res, event1_1_id) + + # Test we get the latest event when we want only the private receipt + res = self.get_success( + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, event1_2_id) + + # Test receipt updating + self.get_success( + self.store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + ) + ) + res = self.get_success( + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] + ) + ) + self.assertEqual(res, event1_2_id) + + # Send some events into the second room + event2_1_id = self.create_and_send_event( + self.room_id2, UserID.from_string(OTHER_USER_ID) + ) + + # Test new room is reflected in what the method returns + self.get_success( + self.store.insert_receipt( + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + ) + ) + res = self.get_success( + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id2, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, event2_1_id) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3c79dabc9f..3405efb6a8 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID @@ -65,71 +64,3 @@ class RoomStoreTestCase(HomeserverTestCase): self.assertIsNone( (self.get_success(self.store.get_room_with_stats("!uknown:test"))), ) - - -class RoomEventsStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): - # Room events need the full datastore, for persist_event() and - # get_room_state() - self.store = hs.get_datastores().main - self._storage_controllers = hs.get_storage_controllers() - self.event_factory = hs.get_event_factory() - - self.room = RoomID.from_string("!abcde:test") - - self.get_success( - self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, - ) - ) - - def inject_room_event(self, **kwargs): - self.get_success( - self._storage_controllers.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) - ) - ) - - def STALE_test_room_name(self): - name = "A-Room-Name" - - self.inject_room_event( - etype=EventTypes.Name, name=name, content={"name": name}, depth=1 - ) - - state = self.get_success( - self._storage_controllers.state.get_current_state( - room_id=self.room.to_string() - ) - ) - - self.assertEqual(1, len(state)) - self.assertObjectHasAttributes( - {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, - state[0], - ) - - def STALE_test_room_topic(self): - topic = "A place for things" - - self.inject_room_event( - etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 - ) - - state = self.get_success( - self._storage_controllers.state.get_current_state( - room_id=self.room.to_string() - ) - ) - - self.assertEqual(1, len(state)) - self.assertObjectHasAttributes( - {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, - state[0], - ) - - # Not testing the various 'level' methods for now because there's lots - # of them and need coalescing; see JIRA SPEC-11 diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1218786d79..8794401823 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -110,60 +110,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) - def test_get_joined_users_from_context(self) -> None: - room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = self.get_success( - event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN - ) - ) - - # first, create a regular event - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, - ) - ) - - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - - # Regression test for #7376: create a state event whose key matches bob's - # user_id, but which is *not* a membership event, and persist that; then check - # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = self.get_success( - event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, - ) - ) - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, - ) - ) - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) @@ -212,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Check that alice's display name is now None self.assertEqual(row[0]["display_name"], None) + def test_room_is_locally_forgotten(self) -> None: + """Test that when the last local user has forgotten a room it is known as forgotten.""" + # join two local and one remote user + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.get_success( + event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join") + ) + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_charlie.to_string(), "join" + ) + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # local users leave the room and the room is not forgotten + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "leave" + ) + ) + self.get_success( + event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave") + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # first user forgets the room, room is not forgotten + self.get_success(self.store.forget(self.u_alice, self.room)) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # second (last local) user forgets the room and the room is forgotten + self.get_success(self.store.forget(self.u_bob, self.room)) + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + def test_join_locally_forgotten_room(self) -> None: + """Tests if a user joins a forgotten room the room is not forgotten anymore.""" + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # after leaving and forget the room, it is forgotten + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "leave" + ) + ) + self.get_success(self.store.forget(self.u_alice, self.room)) + self.assertTrue( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + + # after rejoin the room is not forgotten anymore + self.get_success( + event_injection.inject_member_event( + self.hs, self.room, self.u_alice, "join" + ) + ) + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room)) + ) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8043bdbde2..5564161750 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -369,8 +369,8 @@ class StateStoreTestCase(HomeserverTestCase): state_dict_ids = cache_entry.value self.assertEqual(cache_entry.full, False) - self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)}) - self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) + self.assertEqual(cache_entry.known_absent, set()) + self.assertDictEqual(state_dict_ids, {}) ############################################ # test that things work with a partial cache @@ -387,7 +387,7 @@ class StateStoreTestCase(HomeserverTestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + self.assertDictEqual({}, state_dict) room_id = self.room.to_string() (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( @@ -412,7 +412,7 @@ class StateStoreTestCase(HomeserverTestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + self.assertDictEqual({}, state_dict) (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, @@ -443,7 +443,7 @@ class StateStoreTestCase(HomeserverTestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + self.assertDictEqual({}, state_dict) (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache,