diff options
-rw-r--r-- | synapse/api/filtering.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/filter.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/sync.py | 2 | ||||
-rw-r--r-- | synapse/storage/databases/main/filtering.py | 33 | ||||
-rw-r--r-- | tests/api/test_filtering.py | 26 | ||||
-rw-r--r-- | tests/rest/client/test_filter.py | 2 |
6 files changed, 37 insertions, 32 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index b9f432cc23..870baff2c4 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -165,9 +165,9 @@ class Filtering: self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {}) async def get_user_filter( - self, user_localpart: str, filter_id: Union[int, str] + self, user_id: str, filter_id: Union[int, str] ) -> "FilterCollection": - result = await self.store.get_user_filter(user_localpart, filter_id) + result = await self.store.get_user_filter(user_id, filter_id) return FilterCollection(self._hs, result) def add_user_filter( diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index ab7d8c9419..2db8dacf7c 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet): try: filter_collection = await self.filtering.get_user_filter( - user_localpart=target_user.localpart, filter_id=filter_id_int + user_id=user_id, filter_id=filter_id_int ) except StoreError as e: if e.code != 404: diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 03b0578945..f4039c8450 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet): else: try: filter_collection = await self.filtering.get_user_filter( - user.localpart, filter_id + user.to_string(), filter_id ) except StoreError as err: if err.code != 404: diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 88be0f5f2f..24a60ab6f3 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -24,7 +24,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -34,8 +34,9 @@ if TYPE_CHECKING: class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( - self, user_localpart: str, filter_id: Union[int, str] + self, user_id: str, filter_id: Union[int, str] ) -> JsonDict: + user_localpart = UserID.from_string(user_id).localpart # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -43,13 +44,27 @@ class FilteringWorkerStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = await self.db_pool.simple_select_one_onecol( - table="user_filters", - keyvalues={"user_id": user_localpart, "filter_id": filter_id}, - retcol="filter_json", - allow_none=False, - desc="get_user_filter", - ) + user_localpart = UserID.from_string(user_id).localpart + try: + def_json = await self.db_pool.simple_select_one_onecol( + table="user_filters", + keyvalues={"full_user_id": user_id, "filter_id": filter_id}, + retcol="filter_json", + allow_none=False, + desc="get_user_filter", + ) + except StoreError as e: + if e.code == 404: + # Fall back to the `user_id` column. + def_json = await self.db_pool.simple_select_one_onecol( + table="user_filters", + keyvalues={"user_id": user_localpart, "filter_id": filter_id}, + retcol="filter_json", + allow_none=False, + desc="get_user_filter", + ) + else: + raise return db_to_json(def_json) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 6c6a9ab4b4..48f40da176 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -33,7 +33,9 @@ from synapse.util.frozenutils import freeze from tests import unittest from tests.events.test_utils import MockEvent +user_id = "@test_user:test" user_localpart = "test_user" +user2_id = "@test_user2:test" class FilteringTestCase(unittest.HomeserverTestCase): @@ -453,9 +455,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_presence(presence_states)) @@ -483,9 +483,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_presence(presence_states)) @@ -502,9 +500,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): events = [event] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_room_state(events=events)) @@ -523,9 +519,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): events = [event] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_room_state(events)) @@ -607,9 +601,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): user_filter_json, ( self.get_success( - self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 - ) + self.datastore.get_user_filter(user_id=user_id, filter_id=0) ) ), ) @@ -624,9 +616,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) self.assertEqual(filter.get_filter_json(), user_filter_json) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 91678abf13..436a186f23 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.get_success( - self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.store.get_user_filter(user_id="@apple:test", filter_id=0) ) self.pump() self.assertEqual(filter, self.EXAMPLE_FILTER) |