diff --git a/changelog.d/11391.feature b/changelog.d/11391.feature
new file mode 100644
index 0000000000..4f696285a7
--- /dev/null
+++ b/changelog.d/11391.feature
@@ -0,0 +1 @@
+Store and allow querying of arbitrary event relations.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 120e4807d1..06832221ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1,6 +1,6 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1696,34 +1696,33 @@ class PersistEventsStore:
},
)
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
+ def _handle_event_relations(
+ self, txn: LoggingTransaction, event: EventBase
+ ) -> None:
+ """Handles inserting relation data during persistence of events
Args:
- txn
- event (EventBase)
+ txn: The current database transaction.
+ event: The event which might have relations.
"""
relation = event.content.get("m.relates_to")
if not relation:
# No relations
return
+ # Relations must have a type and parent event ID.
rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- RelationTypes.THREAD,
- ):
- # Unknown relation type
+ if not isinstance(rel_type, str):
return
parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
+ if not isinstance(parent_id, str):
return
- aggregation_key = relation.get("key")
+ # Annotations have a key field.
+ aggregation_key = None
+ if rel_type == RelationTypes.ANNOTATION:
+ aggregation_key = relation.get("key")
self.db_pool.simple_insert_txn(
txn,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index ae3a8a63e4..c88fd35e7f 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+ # The event_thread_relation background update was replaced with the
+ # event_arbitrary_relations one, which handles any relation to avoid
+ # needing to potentially crawl the entire events table in the future.
+ self.db_pool.updates.register_noop_background_update("event_thread_relation")
+
self.db_pool.updates.register_background_update_handler(
- "event_thread_relation", self._event_thread_relation
+ "event_arbitrary_relations",
+ self._event_arbitrary_relations,
)
################################################################################
@@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return result
- async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
- """Background update handler which will store thread relations for existing events."""
+ async def _event_arbitrary_relations(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update handler which will store previously unknown relations for existing events."""
last_event_id = progress.get("last_event_id", "")
- def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+ def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
+ # Fetch events and then filter based on whether the event has a
+ # relation or not.
txn.execute(
"""
SELECT event_id, json FROM event_json
- LEFT JOIN event_relations USING (event_id)
- WHERE event_id > ? AND event_relations.event_id IS NULL
+ WHERE event_id > ?
ORDER BY event_id LIMIT ?
""",
(last_event_id, batch_size),
)
results = list(txn)
- missing_thread_relations = []
+ # (event_id, parent_id, rel_type) for each relation
+ relations_to_insert: List[Tuple[str, str, str]] = []
for (event_id, event_json_raw) in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
continue
- # If there's no relation (or it is not a thread), skip!
+ # If there's no relation, skip!
relates_to = event_json["content"].get("m.relates_to")
if not relates_to or not isinstance(relates_to, dict):
continue
- if relates_to.get("rel_type") != RelationTypes.THREAD:
+
+ # If the relation type or parent event ID is not a string, skip it.
+ #
+ # Do not consider relation types that have existed for a long time,
+ # since they will already be listed in the `event_relations` table.
+ rel_type = relates_to.get("rel_type")
+ if not isinstance(rel_type, str) or rel_type in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
continue
- # Get the parent ID.
parent_id = relates_to.get("event_id")
if not isinstance(parent_id, str):
continue
- missing_thread_relations.append((event_id, parent_id))
+ relations_to_insert.append((event_id, parent_id, rel_type))
+
+ # Insert the missing data, note that we upsert here in case the event
+ # has already been processed.
+ if relations_to_insert:
+ self.db_pool.simple_upsert_many_txn(
+ txn=txn,
+ table="event_relations",
+ key_names=("event_id",),
+ key_values=[(r[0],) for r in relations_to_insert],
+ value_names=("relates_to_id", "relation_type"),
+ value_values=[r[1:] for r in relations_to_insert],
+ )
- # Insert the missing data.
- self.db_pool.simple_insert_many_txn(
- txn=txn,
- table="event_relations",
- values=[
- {
- "event_id": event_id,
- "relates_to_Id": parent_id,
- "relation_type": RelationTypes.THREAD,
- }
- for event_id, parent_id in missing_thread_relations
- ],
- )
+ # Iterate the parent IDs and invalidate caches.
+ for parent_id in {r[1] for r in relations_to_insert}:
+ cache_tuple = (parent_id,)
+ self._invalidate_cache_and_stream(
+ txn, self.get_relations_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_aggregation_groups_for_event, cache_tuple
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_thread_summary, cache_tuple
+ )
if results:
latest_event_id = results[-1][0]
self.db_pool.updates._background_update_progress_txn(
- txn, "event_thread_relation", {"last_event_id": latest_event_id}
+ txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
)
return len(results)
num_rows = await self.db_pool.runInteraction(
- desc="event_thread_relation", func=_event_thread_relation_txn
+ desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
)
if not num_rows:
- await self.db_pool.updates._end_background_update("event_thread_relation")
+ await self.db_pool.updates._end_background_update(
+ "event_arbitrary_relations"
+ )
return num_rows
diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
index d60517f7b4..267b2cb539 100644
--- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql
+++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
@@ -15,4 +15,4 @@
-- Check old events for thread relations.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
- (6502, 'event_thread_relation', '{}');
+ (6507, 'event_arbitrary_relations', '{}');
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index b8a1b92a89..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
# Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return config
def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
self.user_id, self.user_token = self._create_user("alice")
self.user2_id, self.user2_token = self._create_user("bob")
@@ -765,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIn("chunk", channel.json_body)
self.assertEquals(channel.json_body["chunk"], [])
+ def test_unknown_relations(self):
+ """Unknown relations should be accepted."""
+ channel = self._send_relation("m.relation.test", "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to get back a single pagination result, which is the full
+ # relation event we sent above.
+ self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+ self.assert_dict(
+ {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+ channel.json_body["chunk"][0],
+ )
+
+ # We also expect to get the original event (the id of which is self.parent_id)
+ self.assertEquals(
+ channel.json_body["original_event"]["event_id"], self.parent_id
+ )
+
+ # When bundling the unknown relation is not included.
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+ # But unknown relations can be directly queried.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+ % (self.room, self.parent_id),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
def _send_relation(
self,
relation_type: str,
@@ -811,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
access_token = self.login(localpart, "abc123")
return user_id, access_token
+
+ def test_background_update(self):
+ """Test the event_arbitrary_relations background update."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_good = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_event_id_bad = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ self.assertEquals(200, channel.code, channel.json_body)
+ thread_event_id = channel.json_body["event_id"]
+
+ # Clean up the table as if the inserts did not happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="event_relations",
+ column="event_id",
+ iterable=(annotation_event_id_bad, thread_event_id),
+ keyvalues={},
+ desc="RelationsTestCase.test_background_update",
+ )
+ )
+
+ # Only the "good" annotation should be found.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEquals(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good],
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # The "good" annotation and the thread should be found, but not the "bad"
+ # annotation.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertCountEqual(
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
+ [annotation_event_id_good, thread_event_id],
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index c9a08a3420..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,7 +331,12 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """Block until all background database updates have completed."""
+ """
+ Block until all background database updates have completed.
+
+ Note that callers must ensure that a `store` property has been set on
+ the test case (e.g. in `prepare`) before calling this method.
+ """
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
|