summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/filtering.py4
-rw-r--r--synapse/rest/client/filter.py2
-rw-r--r--synapse/rest/client/sync.py2
-rw-r--r--synapse/storage/databases/main/filtering.py33
-rw-r--r--tests/api/test_filtering.py26
-rw-r--r--tests/rest/client/test_filter.py2
6 files changed, 37 insertions, 32 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index b9f432cc23..870baff2c4 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -165,9 +165,9 @@ class Filtering:
         self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
 
     async def get_user_filter(
-        self, user_localpart: str, filter_id: Union[int, str]
+        self, user_id: str, filter_id: Union[int, str]
     ) -> "FilterCollection":
-        result = await self.store.get_user_filter(user_localpart, filter_id)
+        result = await self.store.get_user_filter(user_id, filter_id)
         return FilterCollection(self._hs, result)
 
     def add_user_filter(
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index ab7d8c9419..2db8dacf7c 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
 
         try:
             filter_collection = await self.filtering.get_user_filter(
-                user_localpart=target_user.localpart, filter_id=filter_id_int
+                user_id=user_id, filter_id=filter_id_int
             )
         except StoreError as e:
             if e.code != 404:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 03b0578945..f4039c8450 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
         else:
             try:
                 filter_collection = await self.filtering.get_user_filter(
-                    user.localpart, filter_id
+                    user.to_string(), filter_id
                 )
             except StoreError as err:
                 if err.code != 404:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 88be0f5f2f..24a60ab6f3 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -24,7 +24,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -34,8 +34,9 @@ if TYPE_CHECKING:
 class FilteringWorkerStore(SQLBaseStore):
     @cached(num_args=2)
     async def get_user_filter(
-        self, user_localpart: str, filter_id: Union[int, str]
+        self, user_id: str, filter_id: Union[int, str]
     ) -> JsonDict:
+        user_localpart = UserID.from_string(user_id).localpart
         # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
         # with a coherent error message rather than 500 M_UNKNOWN.
         try:
@@ -43,13 +44,27 @@ class FilteringWorkerStore(SQLBaseStore):
         except ValueError:
             raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
 
-        def_json = await self.db_pool.simple_select_one_onecol(
-            table="user_filters",
-            keyvalues={"user_id": user_localpart, "filter_id": filter_id},
-            retcol="filter_json",
-            allow_none=False,
-            desc="get_user_filter",
-        )
+        user_localpart = UserID.from_string(user_id).localpart
+        try:
+            def_json = await self.db_pool.simple_select_one_onecol(
+                table="user_filters",
+                keyvalues={"full_user_id": user_id, "filter_id": filter_id},
+                retcol="filter_json",
+                allow_none=False,
+                desc="get_user_filter",
+            )
+        except StoreError as e:
+            if e.code == 404:
+                # Fall back to the `user_id` column.
+                def_json = await self.db_pool.simple_select_one_onecol(
+                    table="user_filters",
+                    keyvalues={"user_id": user_localpart, "filter_id": filter_id},
+                    retcol="filter_json",
+                    allow_none=False,
+                    desc="get_user_filter",
+                )
+            else:
+                raise
 
         return db_to_json(def_json)
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6c6a9ab4b4..48f40da176 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -33,7 +33,9 @@ from synapse.util.frozenutils import freeze
 from tests import unittest
 from tests.events.test_utils import MockEvent
 
+user_id = "@test_user:test"
 user_localpart = "test_user"
+user2_id = "@test_user2:test"
 
 
 class FilteringTestCase(unittest.HomeserverTestCase):
@@ -453,9 +455,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         ]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_presence(presence_states))
@@ -483,9 +483,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         ]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart + "2", filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_presence(presence_states))
@@ -502,9 +500,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         events = [event]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_room_state(events=events))
@@ -523,9 +519,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         events = [event]
 
         user_filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         results = self.get_success(user_filter.filter_room_state(events))
@@ -607,9 +601,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             user_filter_json,
             (
                 self.get_success(
-                    self.datastore.get_user_filter(
-                        user_localpart=user_localpart, filter_id=0
-                    )
+                    self.datastore.get_user_filter(user_id=user_id, filter_id=0)
                 )
             ),
         )
@@ -624,9 +616,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         )
 
         filter = self.get_success(
-            self.filtering.get_user_filter(
-                user_localpart=user_localpart, filter_id=filter_id
-            )
+            self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
         )
 
         self.assertEqual(filter.get_filter_json(), user_filter_json)
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 91678abf13..436a186f23 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"filter_id": "0"})
         filter = self.get_success(
-            self.store.get_user_filter(user_localpart="apple", filter_id=0)
+            self.store.get_user_filter(user_id="@apple:test", filter_id=0)
         )
         self.pump()
         self.assertEqual(filter, self.EXAMPLE_FILTER)