diff options
-rw-r--r-- | synapse/handlers/relations.py | 32 | ||||
-rw-r--r-- | synapse/storage/_base.py | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/__init__.py | 25 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 4 | ||||
-rw-r--r-- | tests/storage/test_relations.py | 38 |
5 files changed, 75 insertions, 28 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 4824635162..c8744e3ec7 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -124,7 +124,10 @@ class RelationsHandler: # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). - related_events, next_token = await self._main_store.get_relations_for_event( + ( + related_events, + next_token, + ) = await self._main_store.relations.get_relations_for_event( event_id=event_id, event=event, room_id=room_id, @@ -211,7 +214,7 @@ class RelationsHandler: ShadowBanError if the requester is shadow-banned """ related_event_ids = ( - await self._main_store.get_all_relations_for_event_with_types( + await self._main_store.relations.get_all_relations_for_event_with_types( event_id, relation_types ) ) @@ -250,7 +253,9 @@ class RelationsHandler: A map of event IDs to a list related events. """ - related_events = await self._main_store.get_references_for_events(event_ids) + related_events = await self._main_store.relations.get_references_for_events( + event_ids + ) # Avoid additional logic if there are no ignored users. if not ignored_users: @@ -304,7 +309,7 @@ class RelationsHandler: event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id] # Fetch thread summaries. - summaries = await self._main_store.get_thread_summaries(event_ids) + summaries = await self._main_store.relations.get_thread_summaries(event_ids) # Limit fetching whether the requester has participated in a thread to # events which are thread roots. @@ -320,7 +325,7 @@ class RelationsHandler: # For events the requester did not send, check the database for whether # the requester sent a threaded reply. participated.update( - await self._main_store.get_threads_participated( + await self._main_store.relations.get_threads_participated( [ event_id for event_id in thread_event_ids @@ -331,8 +336,10 @@ class RelationsHandler: ) # Then subtract off the results for any ignored users. - ignored_results = await self._main_store.get_threaded_messages_per_user( - thread_event_ids, ignored_users + ignored_results = ( + await self._main_store.relations.get_threaded_messages_per_user( + thread_event_ids, ignored_users + ) ) # A map of event ID to the thread aggregation. @@ -361,7 +368,10 @@ class RelationsHandler: continue # Attempt to find another event to use as the latest event. - potential_events, _ = await self._main_store.get_relations_for_event( + ( + potential_events, + _, + ) = await self._main_store.relations.get_relations_for_event( event_id, event, room_id, @@ -498,7 +508,7 @@ class RelationsHandler: Note that there is no use in limiting edits by ignored users since the parent event should be ignored in the first place if the user is ignored. """ - edits = await self._main_store.get_applicable_edits( + edits = await self._main_store.relations.get_applicable_edits( [ event_id for event_id, event in events_by_id.items() @@ -553,7 +563,7 @@ class RelationsHandler: # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). - thread_roots, next_batch = await self._main_store.get_threads( + thread_roots, next_batch = await self._main_store.relations.get_threads( room_id=room_id, limit=limit, from_token=from_token ) @@ -565,7 +575,7 @@ class RelationsHandler: # For events the requester did not send, check the database for whether # the requester sent a threaded reply. participated.update( - await self._main_store.get_threads_participated( + await self._main_store.relations.get_threads_participated( [eid for eid, p in participated.items() if not p], user_id, ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 481fec72fe..a30ebda0c8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -25,6 +25,7 @@ from synapse.util.caches.descriptors import CachedFunction if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases import DataStore logger = logging.getLogger(__name__) @@ -44,11 +45,14 @@ class SQLBaseStore(metaclass=ABCMeta): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", + datastore: Optional["DataStore"] = None, ): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self.db_pool = database + # A reference back to the root datastore. + self.datastore = datastore self.external_cached_functions: Dict[str, CachedFunction] = {} diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0032a92f49..eb035591bf 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, cast +import re +from typing import TYPE_CHECKING, Any, List, Match, Optional, Tuple, Type, cast from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig @@ -24,6 +25,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor @@ -121,7 +123,6 @@ class DataStore( UserErasureStore, MonthlyActiveUsersWorkerStore, StatsStore, - RelationsStore, CensorEventsStore, UIAuthStore, EventForwardExtremitiesStore, @@ -129,6 +130,13 @@ class DataStore( LockStore, SessionStore, ): + DATASTORE_CLASSES: List[Type[SQLBaseStore]] = [ + RelationsStore, + ] + + # XXX So mypy knows about dynamic properties. + relations: RelationsStore + def __init__( self, database: DatabasePool, @@ -141,6 +149,19 @@ class DataStore( super().__init__(database, db_conn, hs) + def repl(match: Match[str]) -> str: + return "_" + match.group(0).lower() + + for datastore_class in self.DATASTORE_CLASSES: + name = datastore_class.__name__ + if name.endswith("Store"): + name = name[: -len("Store")] + + name = re.sub(r"[A-Z]", repl, name)[1:] + + store = datastore_class(database, db_conn, hs, self) + setattr(self, name, store) + async def get_users(self) -> List[JsonDict]: """Function to retrieve a list of users in users table. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 4a6c6c724d..2a6af4f8ce 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -52,6 +52,7 @@ from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -95,8 +96,9 @@ class RelationsWorkerStore(SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", + datastore: "DataStore", ): - super().__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs, datastore) self.db_pool.updates.register_background_update_handler( "threads_backfill", self._backfill_threads diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py index cd1d00208b..9043c1cdd6 100644 --- a/tests/storage/test_relations.py +++ b/tests/storage/test_relations.py @@ -58,28 +58,28 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase): Ensure that get_thread_id only searches up the tree for threads. """ # The thread itself and children of it return the thread. - thread_id = self.get_success(self._main_store.get_thread_id("B")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("B")) self.assertEqual("A", thread_id) - thread_id = self.get_success(self._main_store.get_thread_id("C")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("C")) self.assertEqual("A", thread_id) # But the root and events related to the root do not. - thread_id = self.get_success(self._main_store.get_thread_id("A")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("A")) self.assertEqual(MAIN_TIMELINE, thread_id) - thread_id = self.get_success(self._main_store.get_thread_id("D")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("D")) self.assertEqual(MAIN_TIMELINE, thread_id) - thread_id = self.get_success(self._main_store.get_thread_id("E")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("E")) self.assertEqual(MAIN_TIMELINE, thread_id) # Events which are not related to a thread at all should return the # main timeline. - thread_id = self.get_success(self._main_store.get_thread_id("F")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("F")) self.assertEqual(MAIN_TIMELINE, thread_id) - thread_id = self.get_success(self._main_store.get_thread_id("G")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("G")) self.assertEqual(MAIN_TIMELINE, thread_id) def test_get_thread_id_for_receipts(self) -> None: @@ -87,25 +87,35 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase): Ensure that get_thread_id_for_receipts searches up and down the tree for a thread. """ # All of the events are considered related to this thread. - thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A")) + thread_id = self.get_success( + self._main_store.relations.get_thread_id_for_receipts("A") + ) self.assertEqual("A", thread_id) - thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B")) + thread_id = self.get_success( + self._main_store.relations.get_thread_id_for_receipts("B") + ) self.assertEqual("A", thread_id) - thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C")) + thread_id = self.get_success( + self._main_store.relations.get_thread_id_for_receipts("C") + ) self.assertEqual("A", thread_id) - thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D")) + thread_id = self.get_success( + self._main_store.relations.get_thread_id_for_receipts("D") + ) self.assertEqual("A", thread_id) - thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E")) + thread_id = self.get_success( + self._main_store.relations.get_thread_id_for_receipts("E") + ) self.assertEqual("A", thread_id) # Events which are not related to a thread at all should return the # main timeline. - thread_id = self.get_success(self._main_store.get_thread_id("F")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("F")) self.assertEqual(MAIN_TIMELINE, thread_id) - thread_id = self.get_success(self._main_store.get_thread_id("G")) + thread_id = self.get_success(self._main_store.relations.get_thread_id("G")) self.assertEqual(MAIN_TIMELINE, thread_id) |