diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 67401272ac..32a798d74b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -35,66 +35,45 @@ from synapse.util import Clock
from synapse.util.async_helpers import yieldable_gather_results
from tests import unittest
+from tests.test_utils.event_injection import create_event, inject_event
class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
+ self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main
- # insert some test data
- for rid in ("room1", "room2"):
- self.get_success(
- self.store.db_pool.simple_insert(
- "rooms",
- {"room_id": rid, "room_version": 4},
- )
- )
+ self.user = self.register_user("user", "pass")
+ self.token = self.login(self.user, "pass")
+ self.room_id = self.helper.create_room_as(self.user, tok=self.token)
self.event_ids: List[str] = []
- for idx, rid in enumerate(
- (
- "room1",
- "room1",
- "room1",
- "room2",
- )
- ):
- event_json = {"type": f"test {idx}", "room_id": rid}
- event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
- event_id = event.event_id
-
- self.get_success(
- self.store.db_pool.simple_insert(
- "events",
- {
- "event_id": event_id,
- "room_id": rid,
- "topological_ordering": idx,
- "stream_ordering": idx,
- "type": event.type,
- "processed": True,
- "outlier": False,
- },
+ for i in range(3):
+ event = self.get_success(
+ inject_event(
+ hs,
+ room_version=RoomVersions.V7.identifier,
+ room_id=self.room_id,
+ sender=self.user,
+ type="test_event_type",
+ content={"body": f"foobarbaz{i}"},
)
)
- self.get_success(
- self.store.db_pool.simple_insert(
- "event_json",
- {
- "event_id": event_id,
- "room_id": rid,
- "json": json.dumps(event_json),
- "internal_metadata": "{}",
- "format_version": 3,
- },
- )
- )
- self.event_ids.append(event_id)
+
+ self.event_ids.append(event.event_id)
def test_simple(self):
with LoggingContext(name="test") as ctx:
res = self.get_success(
- self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+ self.store.have_seen_events(
+ self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+ )
)
self.assertEqual(res, {self.event_ids[0]})
@@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# a second lookup of the same events should cause no queries
with LoggingContext(name="test") as ctx:
res = self.get_success(
- self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+ self.store.have_seen_events(
+ self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+ )
)
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
@@ -116,11 +97,86 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# looking it up should now cause no db hits
with LoggingContext(name="test") as ctx:
res = self.get_success(
- self.store.have_seen_events("room1", [self.event_ids[0]])
+ self.store.have_seen_events(self.room_id, [self.event_ids[0]])
)
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
+ def test_persisting_event_invalidates_cache(self):
+ """
+ Test to make sure that the `have_seen_event` cache
+ is invalidated after we persist an event and returns
+ the updated value.
+ """
+ event, event_context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=self.user,
+ type="test_event_type",
+ content={"body": "garply"},
+ )
+ )
+
+ with LoggingContext(name="test") as ctx:
+ # First, check `have_seen_event` for an event we have not seen yet
+ # to prime the cache with a `false` value.
+ res = self.get_success(
+ self.store.have_seen_events(event.room_id, [event.event_id])
+ )
+ self.assertEqual(res, set())
+
+ # That should result in a single db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+ # Persist the event which should invalidate or prefill the
+ # `have_seen_event` cache so we don't return stale values.
+ persistence = self.hs.get_storage_controllers().persistence
+ self.get_success(
+ persistence.persist_event(
+ event,
+ event_context,
+ )
+ )
+
+ with LoggingContext(name="test") as ctx:
+ # Check `have_seen_event` again and we should see the updated fact
+ # that we have now seen the event after persisting it.
+ res = self.get_success(
+ self.store.have_seen_events(event.room_id, [event.event_id])
+ )
+ self.assertEqual(res, {event.event_id})
+
+ # That should result in a single db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+ def test_invalidate_cache_by_room_id(self):
+ """
+ Test to make sure that all events associated with the given `(room_id,)`
+ are invalidated in the `have_seen_event` cache.
+ """
+ with LoggingContext(name="test") as ctx:
+ # Prime the cache with some values
+ res = self.get_success(
+ self.store.have_seen_events(self.room_id, self.event_ids)
+ )
+ self.assertEqual(res, set(self.event_ids))
+
+ # That should result in a single db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+ # Clear the cache with any events associated with the `room_id`
+ self.store.have_seen_event.invalidate((self.room_id,))
+
+ with LoggingContext(name="test") as ctx:
+ res = self.get_success(
+ self.store.have_seen_events(self.room_id, self.event_ids)
+ )
+ self.assertEqual(res, set(self.event_ids))
+
+ # Since we cleared the cache, it should result in another db query to lookup
+ self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
class EventCacheTestCase(unittest.HomeserverTestCase):
"""Test that the various layers of event cache works."""
|