summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <seanq@matrix.org>2023-04-15 02:43:04 +0100
committerSean Quah <seanq@matrix.org>2023-04-15 02:52:42 +0100
commit07a5623059961afc3adec4534fb24b49db1a39c4 (patch)
tree9e602119a03698af72664db26c9770d933fa2143
parentDe-localpart `{Filtering,FilteringWorkerStore}.get_user_filter()` (diff)
downloadsynapse-squah/expand_localpart_columns_1.tar.xz
De-localpart `{Filtering,FilteringWorkerStore}.add_user_filter()` github/squah/expand_localpart_columns_1 squah/expand_localpart_columns_1
Signed-off-by: Sean Quah <seanq@matrix.org>
-rw-r--r--synapse/api/filtering.py6
-rw-r--r--synapse/rest/client/filter.py2
-rw-r--r--synapse/storage/databases/main/filtering.py10
-rw-r--r--tests/api/test_filtering.py13
-rw-r--r--tests/rest/client/test_filter.py2
5 files changed, 15 insertions, 18 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 870baff2c4..84e57ac2c1 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -170,11 +170,9 @@ class Filtering:
         result = await self.store.get_user_filter(user_id, filter_id)
         return FilterCollection(self._hs, result)
 
-    def add_user_filter(
-        self, user_localpart: str, user_filter: JsonDict
-    ) -> Awaitable[int]:
+    def add_user_filter(self, user_id: str, user_filter: JsonDict) -> Awaitable[int]:
         self.check_valid_filter(user_filter)
-        return self.store.add_user_filter(user_localpart, user_filter)
+        return self.store.add_user_filter(user_id, user_filter)
 
     # TODO(paul): surely we should probably add a delete_user_filter or
     #   replace_user_filter at some point? There's no REST API specified for
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 2db8dacf7c..156ae61d21 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet):
         set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
 
         filter_id = await self.filtering.add_user_filter(
-            user_localpart=target_user.localpart, user_filter=content
+            user_id=user_id, user_filter=content
         )
 
         return 200, {"filter_id": str(filter_id)}
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 24a60ab6f3..e4de5000d0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -44,7 +44,6 @@ class FilteringWorkerStore(SQLBaseStore):
         except ValueError:
             raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
 
-        user_localpart = UserID.from_string(user_id).localpart
         try:
             def_json = await self.db_pool.simple_select_one_onecol(
                 table="user_filters",
@@ -68,7 +67,8 @@ class FilteringWorkerStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
+    async def add_user_filter(self, user_id: str, user_filter: JsonDict) -> int:
+        user_localpart = UserID.from_string(user_id).localpart
         def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
@@ -92,10 +92,10 @@ class FilteringWorkerStore(SQLBaseStore):
                 filter_id = max_id + 1
 
             sql = (
-                "INSERT INTO user_filters (user_id, filter_id, filter_json)"
-                "VALUES(?, ?, ?)"
+                "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
+                "VALUES(?, ?, ?, ?)"
             )
-            txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
+            txn.execute(sql, (user_id, user_localpart, filter_id, bytearray(def_json)))
 
             return filter_id
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 48f40da176..4a55aa96a5 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -34,7 +34,6 @@ from tests import unittest
 from tests.events.test_utils import MockEvent
 
 user_id = "@test_user:test"
-user_localpart = "test_user"
 user2_id = "@test_user2:test"
 
 
@@ -439,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         user_filter_json = {"presence": {"senders": ["@foo:bar"]}}
         filter_id = self.get_success(
             self.datastore.add_user_filter(
-                user_localpart=user_localpart, user_filter=user_filter_json
+                user_id=user_id, user_filter=user_filter_json
             )
         )
         presence_states = [
@@ -467,7 +466,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
 
         filter_id = self.get_success(
             self.datastore.add_user_filter(
-                user_localpart=user_localpart + "2", user_filter=user_filter_json
+                user_id=user2_id, user_filter=user_filter_json
             )
         )
         presence_states = [
@@ -493,7 +492,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
         filter_id = self.get_success(
             self.datastore.add_user_filter(
-                user_localpart=user_localpart, user_filter=user_filter_json
+                user_id=user_id, user_filter=user_filter_json
             )
         )
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
@@ -510,7 +509,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
         filter_id = self.get_success(
             self.datastore.add_user_filter(
-                user_localpart=user_localpart, user_filter=user_filter_json
+                user_id=user_id, user_filter=user_filter_json
             )
         )
         event = MockEvent(
@@ -592,7 +591,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
 
         filter_id = self.get_success(
             self.filtering.add_user_filter(
-                user_localpart=user_localpart, user_filter=user_filter_json
+                user_id=user_id, user_filter=user_filter_json
             )
         )
 
@@ -611,7 +610,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
 
         filter_id = self.get_success(
             self.datastore.add_user_filter(
-                user_localpart=user_localpart, user_filter=user_filter_json
+                user_id=user_id, user_filter=user_filter_json
             )
         )
 
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 436a186f23..65eff4fe10 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -76,7 +76,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     def test_get_filter(self) -> None:
         filter_id = self.get_success(
             self.filtering.add_user_filter(
-                user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+                user_id="@apple:test", user_filter=self.EXAMPLE_FILTER
             )
         )
         self.reactor.advance(1)