diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index b9f432cc23..870baff2c4 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -165,9 +165,9 @@ class Filtering:
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter(
- self, user_localpart: str, filter_id: Union[int, str]
+ self, user_id: str, filter_id: Union[int, str]
) -> "FilterCollection":
- result = await self.store.get_user_filter(user_localpart, filter_id)
+ result = await self.store.get_user_filter(user_id, filter_id)
return FilterCollection(self._hs, result)
def add_user_filter(
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index ab7d8c9419..2db8dacf7c 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter_collection = await self.filtering.get_user_filter(
- user_localpart=target_user.localpart, filter_id=filter_id_int
+ user_id=user_id, filter_id=filter_id_int
)
except StoreError as e:
if e.code != 404:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 03b0578945..f4039c8450 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
else:
try:
filter_collection = await self.filtering.get_user_filter(
- user.localpart, filter_id
+ user.to_string(), filter_id
)
except StoreError as err:
if err.code != 404:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 88be0f5f2f..24a60ab6f3 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -24,7 +24,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -34,8 +34,9 @@ if TYPE_CHECKING:
class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
- self, user_localpart: str, filter_id: Union[int, str]
+ self, user_id: str, filter_id: Union[int, str]
) -> JsonDict:
+ user_localpart = UserID.from_string(user_id).localpart
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -43,13 +44,27 @@ class FilteringWorkerStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = await self.db_pool.simple_select_one_onecol(
- table="user_filters",
- keyvalues={"user_id": user_localpart, "filter_id": filter_id},
- retcol="filter_json",
- allow_none=False,
- desc="get_user_filter",
- )
+ user_localpart = UserID.from_string(user_id).localpart
+ try:
+ def_json = await self.db_pool.simple_select_one_onecol(
+ table="user_filters",
+ keyvalues={"full_user_id": user_id, "filter_id": filter_id},
+ retcol="filter_json",
+ allow_none=False,
+ desc="get_user_filter",
+ )
+ except StoreError as e:
+ if e.code == 404:
+ # Fall back to the `user_id` column.
+ def_json = await self.db_pool.simple_select_one_onecol(
+ table="user_filters",
+ keyvalues={"user_id": user_localpart, "filter_id": filter_id},
+ retcol="filter_json",
+ allow_none=False,
+ desc="get_user_filter",
+ )
+ else:
+ raise
return db_to_json(def_json)
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6c6a9ab4b4..48f40da176 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -33,7 +33,9 @@ from synapse.util.frozenutils import freeze
from tests import unittest
from tests.events.test_utils import MockEvent
+user_id = "@test_user:test"
user_localpart = "test_user"
+user2_id = "@test_user2:test"
class FilteringTestCase(unittest.HomeserverTestCase):
@@ -453,9 +455,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@@ -483,9 +483,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@@ -502,9 +500,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events=events))
@@ -523,9 +519,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events))
@@ -607,9 +601,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json,
(
self.get_success(
- self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
- )
+ self.datastore.get_user_filter(user_id=user_id, filter_id=0)
)
),
)
@@ -624,9 +616,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
self.assertEqual(filter.get_filter_json(), user_filter_json)
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 91678abf13..436a186f23 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
- self.store.get_user_filter(user_localpart="apple", filter_id=0)
+ self.store.get_user_filter(user_id="@apple:test", filter_id=0)
)
self.pump()
self.assertEqual(filter, self.EXAMPLE_FILTER)
|