diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0d4e74d637..3aeec0dc0f 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,25 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import tests.unittest
import tests.utils
-class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_get_prev_events_for_room(self):
- room_id = '@ROOM:local'
+ room_id = "@ROOM:local"
# add a bunch of events and hashes to act as forward extremities
def insert_event(txn, i):
- event_id = '$event_%i:local' % i
+ event_id = "$event_%i:local" % i
txn.execute(
(
@@ -45,33 +40,194 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
txn.execute(
(
- 'INSERT INTO event_forward_extremities (room_id, event_id) '
- 'VALUES (?, ?)'
+ "INSERT INTO event_forward_extremities (room_id, event_id) "
+ "VALUES (?, ?)"
),
(room_id, event_id),
)
txn.execute(
(
- 'INSERT INTO event_reference_hashes '
- '(event_id, algorithm, hash) '
+ "INSERT INTO event_reference_hashes "
+ "(event_id, algorithm, hash) "
"VALUES (?, 'sha256', ?)"
),
- (event_id, b'ffff'),
+ (event_id, bytearray(b"ffff")),
)
- for i in range(0, 11):
- yield self.store.runInteraction("insert", insert_event, i)
+ for i in range(0, 20):
+ self.get_success(self.store.db.runInteraction("insert", insert_event, i))
- # this should get the last five and five others
- r = yield self.store.get_prev_events_for_room(room_id)
+ # this should get the last ten
+ r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r))
- for i in range(0, 5):
- el = r[i]
- depth = el[2]
- self.assertEqual(10 - i, depth)
-
- for i in range(5, 5):
- el = r[i]
- depth = el[2]
- self.assertLessEqual(5, depth)
+ for i in range(0, 10):
+ self.assertEqual("$event_%i:local" % (19 - i), r[i])
+
+ def test_get_rooms_with_many_extremities(self):
+ room1 = "#room1"
+ room2 = "#room2"
+ room3 = "#room3"
+
+ def insert_event(txn, i, room_id):
+ event_id = "$event_%i:local" % i
+ txn.execute(
+ (
+ "INSERT INTO event_forward_extremities (room_id, event_id) "
+ "VALUES (?, ?)"
+ ),
+ (room_id, event_id),
+ )
+
+ for i in range(0, 20):
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room1)
+ )
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room2)
+ )
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room3)
+ )
+
+ # Test simple case
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
+ self.assertEqual(len(r), 3)
+
+ # Does filter work?
+
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
+ self.assertTrue(room2 in r)
+ self.assertTrue(room3 in r)
+ self.assertEqual(len(r), 2)
+
+ r = self.get_success(
+ self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+ )
+ self.assertEqual(r, [room3])
+
+ # Does filter and limit work?
+
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
+ self.assertTrue(r == [room2] or r == [room3])
+
+ def test_auth_difference(self):
+ room_id = "@ROOM:local"
+
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ยด |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
+
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
+
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+
+ def insert_event(txn, event_id, stream_ordering):
+
+ depth = depth_map[event_id]
+
+ self.store.db.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.store.db.simple_insert_many_txn(
+ txn,
+ table="event_auth",
+ values=[
+ {"event_id": event_id, "room_id": room_id, "auth_id": a}
+ for a in auth_graph[event_id]
+ ],
+ )
+
+ next_stream_ordering = 0
+ for event_id in auth_graph:
+ next_stream_ordering += 1
+ self.get_success(
+ self.store.db.runInteraction(
+ "insert", insert_event, event_id, next_stream_ordering
+ )
+ )
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ self.assertSetEqual(difference, set())
|