summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-12-14 12:08:21 -0500
committerPatrick Cloke <patrickc@matrix.org>2023-05-17 14:26:01 -0400
commitc25ec34d732952dcc1a4ecb89652f11a9cd43a48 (patch)
tree4728156955ec3f05fa52330af12becd5980f4321
parentAdd a new admin API to create a new device for a user. (#15611) (diff)
downloadsynapse-c25ec34d732952dcc1a4ecb89652f11a9cd43a48.tar.xz
✨ Magic ✨
-rw-r--r--synapse/handlers/relations.py32
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/databases/main/__init__.py25
-rw-r--r--synapse/storage/databases/main/relations.py4
-rw-r--r--tests/storage/test_relations.py38
5 files changed, 75 insertions, 28 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 4824635162..c8744e3ec7 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -124,7 +124,10 @@ class RelationsHandler:
         # Note that ignored users are not passed into get_relations_for_event
         # below. Ignored users are handled in filter_events_for_client (and by
         # not passing them in here we should get a better cache hit rate).
-        related_events, next_token = await self._main_store.get_relations_for_event(
+        (
+            related_events,
+            next_token,
+        ) = await self._main_store.relations.get_relations_for_event(
             event_id=event_id,
             event=event,
             room_id=room_id,
@@ -211,7 +214,7 @@ class RelationsHandler:
             ShadowBanError if the requester is shadow-banned
         """
         related_event_ids = (
-            await self._main_store.get_all_relations_for_event_with_types(
+            await self._main_store.relations.get_all_relations_for_event_with_types(
                 event_id, relation_types
             )
         )
@@ -250,7 +253,9 @@ class RelationsHandler:
             A map of event IDs to a list related events.
         """
 
-        related_events = await self._main_store.get_references_for_events(event_ids)
+        related_events = await self._main_store.relations.get_references_for_events(
+            event_ids
+        )
 
         # Avoid additional logic if there are no ignored users.
         if not ignored_users:
@@ -304,7 +309,7 @@ class RelationsHandler:
         event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
 
         # Fetch thread summaries.
-        summaries = await self._main_store.get_thread_summaries(event_ids)
+        summaries = await self._main_store.relations.get_thread_summaries(event_ids)
 
         # Limit fetching whether the requester has participated in a thread to
         # events which are thread roots.
@@ -320,7 +325,7 @@ class RelationsHandler:
         # For events the requester did not send, check the database for whether
         # the requester sent a threaded reply.
         participated.update(
-            await self._main_store.get_threads_participated(
+            await self._main_store.relations.get_threads_participated(
                 [
                     event_id
                     for event_id in thread_event_ids
@@ -331,8 +336,10 @@ class RelationsHandler:
         )
 
         # Then subtract off the results for any ignored users.
-        ignored_results = await self._main_store.get_threaded_messages_per_user(
-            thread_event_ids, ignored_users
+        ignored_results = (
+            await self._main_store.relations.get_threaded_messages_per_user(
+                thread_event_ids, ignored_users
+            )
         )
 
         # A map of event ID to the thread aggregation.
@@ -361,7 +368,10 @@ class RelationsHandler:
                     continue
 
                 # Attempt to find another event to use as the latest event.
-                potential_events, _ = await self._main_store.get_relations_for_event(
+                (
+                    potential_events,
+                    _,
+                ) = await self._main_store.relations.get_relations_for_event(
                     event_id,
                     event,
                     room_id,
@@ -498,7 +508,7 @@ class RelationsHandler:
             Note that there is no use in limiting edits by ignored users since the
             parent event should be ignored in the first place if the user is ignored.
             """
-            edits = await self._main_store.get_applicable_edits(
+            edits = await self._main_store.relations.get_applicable_edits(
                 [
                     event_id
                     for event_id, event in events_by_id.items()
@@ -553,7 +563,7 @@ class RelationsHandler:
         # Note that ignored users are not passed into get_threads
         # below. Ignored users are handled in filter_events_for_client (and by
         # not passing them in here we should get a better cache hit rate).
-        thread_roots, next_batch = await self._main_store.get_threads(
+        thread_roots, next_batch = await self._main_store.relations.get_threads(
             room_id=room_id, limit=limit, from_token=from_token
         )
 
@@ -565,7 +575,7 @@ class RelationsHandler:
             # For events the requester did not send, check the database for whether
             # the requester sent a threaded reply.
             participated.update(
-                await self._main_store.get_threads_participated(
+                await self._main_store.relations.get_threads_participated(
                     [eid for eid, p in participated.items() if not p],
                     user_id,
                 )
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 481fec72fe..a30ebda0c8 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -25,6 +25,7 @@ from synapse.util.caches.descriptors import CachedFunction
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +45,14 @@ class SQLBaseStore(metaclass=ABCMeta):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
+        datastore: Optional["DataStore"] = None,
     ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
         self.db_pool = database
+        # A reference back to the root datastore.
+        self.datastore = datastore
 
         self.external_cached_functions: Dict[str, CachedFunction] = {}
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0032a92f49..eb035591bf 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+import re
+from typing import TYPE_CHECKING, Any, List, Match, Optional, Tuple, Type, cast
 
 from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
@@ -24,6 +25,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
@@ -121,7 +123,6 @@ class DataStore(
     UserErasureStore,
     MonthlyActiveUsersWorkerStore,
     StatsStore,
-    RelationsStore,
     CensorEventsStore,
     UIAuthStore,
     EventForwardExtremitiesStore,
@@ -129,6 +130,13 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
+    DATASTORE_CLASSES: List[Type[SQLBaseStore]] = [
+        RelationsStore,
+    ]
+
+    # XXX So mypy knows about dynamic properties.
+    relations: RelationsStore
+
     def __init__(
         self,
         database: DatabasePool,
@@ -141,6 +149,19 @@ class DataStore(
 
         super().__init__(database, db_conn, hs)
 
+        def repl(match: Match[str]) -> str:
+            return "_" + match.group(0).lower()
+
+        for datastore_class in self.DATASTORE_CLASSES:
+            name = datastore_class.__name__
+            if name.endswith("Store"):
+                name = name[: -len("Store")]
+
+            name = re.sub(r"[A-Z]", repl, name)[1:]
+
+            store = datastore_class(database, db_conn, hs, self)
+            setattr(self, name, store)
+
     async def get_users(self) -> List[JsonDict]:
         """Function to retrieve a list of users in users table.
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4a6c6c724d..2a6af4f8ce 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -52,6 +52,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -95,8 +96,9 @@ class RelationsWorkerStore(SQLBaseStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
+        datastore: "DataStore",
     ):
-        super().__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs, datastore)
 
         self.db_pool.updates.register_background_update_handler(
             "threads_backfill", self._backfill_threads
diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py
index cd1d00208b..9043c1cdd6 100644
--- a/tests/storage/test_relations.py
+++ b/tests/storage/test_relations.py
@@ -58,28 +58,28 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
         Ensure that get_thread_id only searches up the tree for threads.
         """
         # The thread itself and children of it return the thread.
-        thread_id = self.get_success(self._main_store.get_thread_id("B"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("B"))
         self.assertEqual("A", thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id("C"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("C"))
         self.assertEqual("A", thread_id)
 
         # But the root and events related to the root do not.
-        thread_id = self.get_success(self._main_store.get_thread_id("A"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("A"))
         self.assertEqual(MAIN_TIMELINE, thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id("D"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("D"))
         self.assertEqual(MAIN_TIMELINE, thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id("E"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("E"))
         self.assertEqual(MAIN_TIMELINE, thread_id)
 
         # Events which are not related to a thread at all should return the
         # main timeline.
-        thread_id = self.get_success(self._main_store.get_thread_id("F"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
         self.assertEqual(MAIN_TIMELINE, thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id("G"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
         self.assertEqual(MAIN_TIMELINE, thread_id)
 
     def test_get_thread_id_for_receipts(self) -> None:
@@ -87,25 +87,35 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
         Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
         """
         # All of the events are considered related to this thread.
-        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
+        thread_id = self.get_success(
+            self._main_store.relations.get_thread_id_for_receipts("A")
+        )
         self.assertEqual("A", thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
+        thread_id = self.get_success(
+            self._main_store.relations.get_thread_id_for_receipts("B")
+        )
         self.assertEqual("A", thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
+        thread_id = self.get_success(
+            self._main_store.relations.get_thread_id_for_receipts("C")
+        )
         self.assertEqual("A", thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
+        thread_id = self.get_success(
+            self._main_store.relations.get_thread_id_for_receipts("D")
+        )
         self.assertEqual("A", thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
+        thread_id = self.get_success(
+            self._main_store.relations.get_thread_id_for_receipts("E")
+        )
         self.assertEqual("A", thread_id)
 
         # Events which are not related to a thread at all should return the
         # main timeline.
-        thread_id = self.get_success(self._main_store.get_thread_id("F"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
         self.assertEqual(MAIN_TIMELINE, thread_id)
 
-        thread_id = self.get_success(self._main_store.get_thread_id("G"))
+        thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
         self.assertEqual(MAIN_TIMELINE, thread_id)