diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 38963ce4a7..46d829b062 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
def test_simple(self):
"""Test that we cache events that we pull from the DB."""
@@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
"""
# Reset the event cache
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
with LoggingContext("test") as ctx:
# We keep hold of the event event though we never use it.
@@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
# Reset the event cache
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
with LoggingContext("test") as ctx:
self.get_success(self.store.get_event(self.event_id))
@@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.event_id = res["event_id"]
# Reset the event cache so the tests start with it empty
- self.store._get_event_cache.clear()
+ self.get_success(self.store._get_event_cache.clear())
@contextmanager
def blocking_get_event_calls(
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+
+from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.room import _BackgroundUpdates
@@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_creator_after, self.user_id)
+
+ def test_background_add_room_type_column(self):
+ """Test that the background update to populate the `room_type` column in
+ `room_stats_state` works properly.
+ """
+
+ # Create a room without a type
+ room_id = self._generate_room()
+
+ # Get event_id of the m.room.create event
+ event_id = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="current_state_events",
+ keyvalues={
+ "room_id": room_id,
+ "type": "m.room.create",
+ },
+ retcol="event_id",
+ )
+ )
+
+ # Fake a room creation event with a room type
+ event = {
+ "content": {
+ "creator": "@user:server.org",
+ "room_version": "9",
+ "type": RoomTypes.SPACE,
+ },
+ "type": "m.room.create",
+ }
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": json.dumps(event)},
+ desc="test",
+ )
+ )
+
+ # Insert and run the background update
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ self.wait_for_background_updates()
+
+ # Make sure the background update filled in the room type
+ room_type_after = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ retcol="room_type",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_type_after, RoomTypes.SPACE)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0f9add4841..fc43d7edd1 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,143 +12,175 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
-PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
-HIGHLIGHT = [
- "notify",
- {"set_tweak": "sound", "value": "default"},
- {"set_tweak": "highlight"},
-]
-
class EventPushActionsStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self.persist_events_store = hs.get_datastores().persist_events
+ persist_events_store = hs.get_datastores().persist_events
+ assert persist_events_store is not None
+ self.persist_events_store = persist_events_store
- def test_get_unread_push_actions_for_user_in_range_for_http(self):
+ def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- def test_get_unread_push_actions_for_user_in_range_for_email(self):
+ def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- def test_count_aggregation(self):
- room_id = "!foo:example.com"
- user_id = "@user1235:example.com"
+ def test_count_aggregation(self) -> None:
+ # Create a user to receive notifications and send receipts.
+ user_id = self.register_user("user1235", "pass")
+ token = self.login("user1235", "pass")
+
+ # And another users to send events.
+ other_id = self.register_user("other", "pass")
+ other_token = self.login("other", "pass")
+
+ # Create a room and put both users in it.
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ self.helper.join(room_id, other_id, tok=other_token)
- def _assert_counts(noitf_count, highlight_count):
+ last_event_id: str
+
+ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
- "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+ "get-unread-counts",
+ self.store._get_unread_counts_by_receipt_txn,
+ room_id,
+ user_id,
)
)
self.assertEqual(
counts,
NotifCounts(
notify_count=noitf_count,
- unread_count=0, # Unread counts are tested in the sync tests.
+ unread_count=0,
highlight_count=highlight_count,
),
)
- def _inject_actions(stream, action):
- event = Mock()
- event.room_id = room_id
- event.event_id = "$test:example.com"
- event.internal_metadata.stream_ordering = stream
- event.internal_metadata.is_outlier.return_value = False
- event.depth = stream
-
- self.get_success(
- self.store.add_push_actions_to_staging(
- event.event_id,
- {user_id: action},
- False,
- )
- )
- self.get_success(
- self.store.db_pool.runInteraction(
- "",
- self.persist_events_store._set_push_actions_for_event_and_users_txn,
- [(event, None)],
- [(event, None)],
- )
+ def _create_event(highlight: bool = False) -> str:
+ result = self.helper.send_event(
+ room_id,
+ type="m.room.message",
+ content={"msgtype": "m.text", "body": user_id if highlight else "msg"},
+ tok=other_token,
)
+ nonlocal last_event_id
+ last_event_id = result["event_id"]
+ return last_event_id
- def _rotate(stream):
- self.get_success(
- self.store.db_pool.runInteraction(
- "", self.store._rotate_notifs_before_txn, stream
- )
- )
+ def _rotate() -> None:
+ self.get_success(self.store._rotate_notifs())
- def _mark_read(stream, depth):
+ def _mark_read(event_id: str) -> None:
self.get_success(
- self.store.db_pool.runInteraction(
- "",
- self.store._remove_old_push_actions_before_txn,
+ self.store.insert_receipt(
room_id,
- user_id,
- stream,
+ "m.read",
+ user_id=user_id,
+ event_ids=[event_id],
+ data={},
)
)
_assert_counts(0, 0)
- _inject_actions(1, PlAIN_NOTIF)
+ _create_event()
_assert_counts(1, 0)
- _rotate(2)
+ _rotate()
_assert_counts(1, 0)
- _inject_actions(3, PlAIN_NOTIF)
+ event_id = _create_event()
_assert_counts(2, 0)
- _rotate(4)
+ _rotate()
_assert_counts(2, 0)
- _inject_actions(5, PlAIN_NOTIF)
- _mark_read(3, 3)
+ _create_event()
+ _mark_read(event_id)
_assert_counts(1, 0)
- _mark_read(5, 5)
+ _mark_read(last_event_id)
_assert_counts(0, 0)
- _inject_actions(6, PlAIN_NOTIF)
- _rotate(7)
+ _create_event()
+ _rotate()
+ _assert_counts(1, 0)
- self.get_success(
- self.store.db_pool.simple_delete(
- table="event_push_actions", keyvalues={"1": 1}, desc=""
+ # Delete old event push actions, this should not affect the (summarised) count.
+ #
+ # All event push actions are kept for 24 hours, so need to move forward
+ # in time.
+ self.pump(60 * 60 * 24)
+ self.get_success(self.store._remove_old_push_actions_that_have_rotated())
+ # Double check that the event push actions have been cleared (i.e. that
+ # any results *must* come from the summary).
+ result = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="event_push_actions",
+ keyvalues={"1": 1},
+ retcols=("event_id",),
+ desc="",
)
)
-
+ self.assertEqual(result, [])
_assert_counts(1, 0)
- _mark_read(7, 7)
+ _mark_read(last_event_id)
_assert_counts(0, 0)
- _inject_actions(8, HIGHLIGHT)
+ event_id = _create_event(True)
_assert_counts(1, 1)
- _rotate(9)
+ _rotate()
_assert_counts(1, 1)
- _rotate(10)
+
+ # Check that adding another notification and rotating after highlight
+ # works.
+ _create_event()
+ _rotate()
+ _assert_counts(2, 1)
+
+ # Check that sending read receipts at different points results in the
+ # right counts.
+ _mark_read(event_id)
+ _assert_counts(1, 0)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0)
+
+ _create_event(True)
_assert_counts(1, 1)
+ _mark_read(last_event_id)
+ _assert_counts(0, 0)
+ _rotate()
+ _assert_counts(0, 0)
- def test_find_first_stream_ordering_after_ts(self):
- def add_event(so, ts):
+ def test_find_first_stream_ordering_after_ts(self) -> None:
+ def add_event(so: int, ts: int) -> None:
self.get_success(
self.store.db_pool.simple_insert(
"events",
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 2ff88e64a5..3ce4f35cb7 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
- self.state.compute_event_context(event, state_ids_before_event=state)
+ self.state.compute_event_context(
+ event,
+ state_ids_before_event=state,
+ partial_state=None if state is None else False,
+ )
)
self.get_success(self._persistence.persist_event(event, context))
@@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
self.state.compute_event_context(
remote_event_2,
state_ids_before_event=state_before_gap,
+ partial_state=False,
)
)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 8dfaa0559b..9c1182ed16 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase):
)
# The events aren't found.
- self.store._invalidate_get_event_cache(create_event.event_id)
+ self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
new file mode 100644
index 0000000000..c89bfff241
--- /dev/null
+++ b/tests/storage/test_receipts.py
@@ -0,0 +1,272 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.constants import ReceiptTypes
+from synapse.types import UserID, create_requester
+
+from tests.test_utils.event_injection import create_event
+from tests.unittest import HomeserverTestCase
+
+OTHER_USER_ID = "@other:test"
+OUR_USER_ID = "@our:test"
+
+
+class ReceiptTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver) -> None:
+ super().prepare(reactor, clock, homeserver)
+
+ self.store = homeserver.get_datastores().main
+
+ self.room_creator = homeserver.get_room_creation_handler()
+ self.persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
+
+ # Create a test user
+ self.ourUser = UserID.from_string(OUR_USER_ID)
+ self.ourRequester = create_requester(self.ourUser)
+
+ # Create a second test user
+ self.otherUser = UserID.from_string(OTHER_USER_ID)
+ self.otherRequester = create_requester(self.otherUser)
+
+ # Create a test room
+ info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
+ self.room_id1 = info["room_id"]
+
+ # Create a second test room
+ info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
+ self.room_id2 = info["room_id"]
+
+ # Join the second user to the first room
+ memberEvent, memberEventContext = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id1,
+ type="m.room.member",
+ sender=self.otherRequester.user.to_string(),
+ state_key=self.otherRequester.user.to_string(),
+ content={"membership": "join"},
+ )
+ )
+ self.get_success(
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
+ )
+
+ # Join the second user to the second room
+ memberEvent, memberEventContext = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id2,
+ type="m.room.member",
+ sender=self.otherRequester.user.to_string(),
+ state_key=self.otherRequester.user.to_string(),
+ content={"membership": "join"},
+ )
+ )
+ self.get_success(
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
+ )
+
+ def test_return_empty_with_no_data(self) -> None:
+ res = self.get_success(
+ self.store.get_receipts_for_user(
+ OUR_USER_ID,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
+ )
+ )
+ self.assertEqual(res, {})
+
+ res = self.get_success(
+ self.store.get_receipts_for_user_with_orderings(
+ OUR_USER_ID,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
+ )
+ )
+ self.assertEqual(res, {})
+
+ res = self.get_success(
+ self.store.get_last_receipt_event_id_for_user(
+ OUR_USER_ID,
+ self.room_id1,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
+ )
+ )
+ self.assertEqual(res, None)
+
+ def test_get_receipts_for_user(self) -> None:
+ # Send some events into the first room
+ event1_1_id = self.create_and_send_event(
+ self.room_id1, UserID.from_string(OTHER_USER_ID)
+ )
+ event1_2_id = self.create_and_send_event(
+ self.room_id1, UserID.from_string(OTHER_USER_ID)
+ )
+
+ # Send public read receipt for the first event
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ )
+ )
+ # Send private read receipt for the second event
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ )
+ )
+
+ # Test we get the latest event when we want both private and public receipts
+ res = self.get_success(
+ self.store.get_receipts_for_user(
+ OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ )
+ )
+ self.assertEqual(res, {self.room_id1: event1_2_id})
+
+ # Test we get the older event when we want only public receipt
+ res = self.get_success(
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
+ )
+ self.assertEqual(res, {self.room_id1: event1_1_id})
+
+ # Test we get the latest event when we want only the public receipt
+ res = self.get_success(
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE])
+ )
+ self.assertEqual(res, {self.room_id1: event1_2_id})
+
+ # Test receipt updating
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+ )
+ )
+ res = self.get_success(
+ self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
+ )
+ self.assertEqual(res, {self.room_id1: event1_2_id})
+
+ # Send some events into the second room
+ event2_1_id = self.create_and_send_event(
+ self.room_id2, UserID.from_string(OTHER_USER_ID)
+ )
+
+ # Test new room is reflected in what the method returns
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+ )
+ )
+ res = self.get_success(
+ self.store.get_receipts_for_user(
+ OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ )
+ )
+ self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
+
+ def test_get_last_receipt_event_id_for_user(self) -> None:
+ # Send some events into the first room
+ event1_1_id = self.create_and_send_event(
+ self.room_id1, UserID.from_string(OTHER_USER_ID)
+ )
+ event1_2_id = self.create_and_send_event(
+ self.room_id1, UserID.from_string(OTHER_USER_ID)
+ )
+
+ # Send public read receipt for the first event
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+ )
+ )
+ # Send private read receipt for the second event
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+ )
+ )
+
+ # Test we get the latest event when we want both private and public receipts
+ res = self.get_success(
+ self.store.get_last_receipt_event_id_for_user(
+ OUR_USER_ID,
+ self.room_id1,
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ )
+ )
+ self.assertEqual(res, event1_2_id)
+
+ # Test we get the older event when we want only public receipt
+ res = self.get_success(
+ self.store.get_last_receipt_event_id_for_user(
+ OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
+ )
+ )
+ self.assertEqual(res, event1_1_id)
+
+ # Test we get the latest event when we want only the private receipt
+ res = self.get_success(
+ self.store.get_last_receipt_event_id_for_user(
+ OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
+ )
+ )
+ self.assertEqual(res, event1_2_id)
+
+ # Test receipt updating
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+ )
+ )
+ res = self.get_success(
+ self.store.get_last_receipt_event_id_for_user(
+ OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
+ )
+ )
+ self.assertEqual(res, event1_2_id)
+
+ # Send some events into the second room
+ event2_1_id = self.create_and_send_event(
+ self.room_id2, UserID.from_string(OTHER_USER_ID)
+ )
+
+ # Test new room is reflected in what the method returns
+ self.get_success(
+ self.store.insert_receipt(
+ self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+ )
+ )
+ res = self.get_success(
+ self.store.get_last_receipt_event_id_for_user(
+ OUR_USER_ID,
+ self.room_id2,
+ [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+ )
+ )
+ self.assertEqual(res, event2_1_id)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3c79dabc9f..3405efb6a8 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
@@ -65,71 +64,3 @@ class RoomStoreTestCase(HomeserverTestCase):
self.assertIsNone(
(self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
-
-
-class RoomEventsStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
- # Room events need the full datastore, for persist_event() and
- # get_room_state()
- self.store = hs.get_datastores().main
- self._storage_controllers = hs.get_storage_controllers()
- self.event_factory = hs.get_event_factory()
-
- self.room = RoomID.from_string("!abcde:test")
-
- self.get_success(
- self.store.store_room(
- self.room.to_string(),
- room_creator_user_id="@creator:text",
- is_public=True,
- room_version=RoomVersions.V1,
- )
- )
-
- def inject_room_event(self, **kwargs):
- self.get_success(
- self._storage_controllers.persistence.persist_event(
- self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
- )
- )
-
- def STALE_test_room_name(self):
- name = "A-Room-Name"
-
- self.inject_room_event(
- etype=EventTypes.Name, name=name, content={"name": name}, depth=1
- )
-
- state = self.get_success(
- self._storage_controllers.state.get_current_state(
- room_id=self.room.to_string()
- )
- )
-
- self.assertEqual(1, len(state))
- self.assertObjectHasAttributes(
- {"type": "m.room.name", "room_id": self.room.to_string(), "name": name},
- state[0],
- )
-
- def STALE_test_room_topic(self):
- topic = "A place for things"
-
- self.inject_room_event(
- etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
- )
-
- state = self.get_success(
- self._storage_controllers.state.get_current_state(
- room_id=self.room.to_string()
- )
- )
-
- self.assertEqual(1, len(state))
- self.assertObjectHasAttributes(
- {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic},
- state[0],
- )
-
- # Not testing the various 'level' methods for now because there's lots
- # of them and need coalescing; see JIRA SPEC-11
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1218786d79..8794401823 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -110,60 +110,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
- def test_get_joined_users_from_context(self) -> None:
- room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- bob_event = self.get_success(
- event_injection.inject_member_event(
- self.hs, room, self.u_bob, Membership.JOIN
- )
- )
-
- # first, create a regular event
- event, context = self.get_success(
- event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[bob_event.event_id],
- type="m.test.1",
- content={},
- )
- )
-
- users = self.get_success(
- self.store.get_joined_users_from_context(event, context)
- )
- self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
- # Regression test for #7376: create a state event whose key matches bob's
- # user_id, but which is *not* a membership event, and persist that; then check
- # that `get_joined_users_from_context` returns the correct users for the next event.
- non_member_event = self.get_success(
- event_injection.inject_event(
- self.hs,
- room_id=room,
- sender=self.u_bob,
- prev_event_ids=[bob_event.event_id],
- type="m.test.2",
- state_key=self.u_bob,
- content={},
- )
- )
- event, context = self.get_success(
- event_injection.create_event(
- self.hs,
- room_id=room,
- sender=self.u_alice,
- prev_event_ids=[non_member_event.event_id],
- type="m.test.3",
- content={},
- )
- )
- users = self.get_success(
- self.store.get_joined_users_from_context(event, context)
- )
- self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
@@ -212,6 +158,75 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Check that alice's display name is now None
self.assertEqual(row[0]["display_name"], None)
+ def test_room_is_locally_forgotten(self) -> None:
+ """Test that when the last local user has forgotten a room it is known as forgotten."""
+ # join two local and one remote user
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "join")
+ )
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_charlie.to_string(), "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # local users leave the room and the room is not forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(
+ event_injection.inject_member_event(self.hs, self.room, self.u_bob, "leave")
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # first user forgets the room, room is not forgotten
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # second (last local) user forgets the room and the room is forgotten
+ self.get_success(self.store.forget(self.u_bob, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ def test_join_locally_forgotten_room(self) -> None:
+ """Tests if a user joins a forgotten room the room is not forgotten anymore."""
+ self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after leaving and forget the room, it is forgotten
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "leave"
+ )
+ )
+ self.get_success(self.store.forget(self.u_alice, self.room))
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
+ # after rejoin the room is not forgotten anymore
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room, self.u_alice, "join"
+ )
+ )
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room))
+ )
+
class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8043bdbde2..5564161750 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -369,8 +369,8 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict_ids = cache_entry.value
self.assertEqual(cache_entry.full, False)
- self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
- self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
+ self.assertEqual(cache_entry.known_absent, set())
+ self.assertDictEqual(state_dict_ids, {})
############################################
# test that things work with a partial cache
@@ -387,7 +387,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
room_id = self.room.to_string()
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
@@ -412,7 +412,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
@@ -443,7 +443,7 @@ class StateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(is_all, False)
- self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ self.assertDictEqual({}, state_dict)
(state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
|