summary refs log tree commit diff
path: root/tests/api
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api')
-rw-r--r--tests/api/test_filtering.py108
1 files changed, 57 insertions, 51 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index aa93616a9f..babf4c37f1 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -25,6 +25,7 @@ from tests.utils import (
 
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.api.filtering import Filter
 
 user_localpart = "test_user"
 MockEvent = namedtuple("MockEvent", "sender type room_id")
@@ -53,6 +54,7 @@ class FilteringTestCase(unittest.TestCase):
         )
 
         self.filtering = hs.get_filtering()
+        self.filter = Filter({})
 
         self.datastore = hs.get_datastore()
 
@@ -66,7 +68,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_types_works_with_wildcards(self):
@@ -79,7 +81,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_types_works_with_unknowns(self):
@@ -92,7 +94,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_works_with_literals(self):
@@ -105,7 +107,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_works_with_wildcards(self):
@@ -118,7 +120,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_works_with_unknowns(self):
@@ -131,7 +133,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_types_takes_priority_over_types(self):
@@ -145,7 +147,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_senders_works_with_literals(self):
@@ -158,7 +160,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_senders_works_with_unknowns(self):
@@ -171,7 +173,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_senders_works_with_literals(self):
@@ -184,7 +186,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_senders_works_with_unknowns(self):
@@ -197,7 +199,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_senders_takes_priority_over_senders(self):
@@ -211,7 +213,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!foo:bar"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_rooms_works_with_literals(self):
@@ -224,7 +226,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_rooms_works_with_unknowns(self):
@@ -237,7 +239,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_rooms_works_with_literals(self):
@@ -250,7 +252,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_rooms_works_with_unknowns(self):
@@ -263,7 +265,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!anothersecretbase:unknown"
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -277,7 +279,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!secretbase:unknown"
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event(self):
@@ -295,7 +297,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertTrue(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event_bad_sender(self):
@@ -313,7 +315,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event_bad_room(self):
@@ -331,7 +333,7 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!piggyshouse:muppets"  # nope
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     def test_definition_combined_event_bad_type(self):
@@ -349,12 +351,12 @@ class FilteringTestCase(unittest.TestCase):
             room_id="!stage:unknown"  # yup
         )
         self.assertFalse(
-            self.filtering._passes_definition(definition, event)
+            self.filter._passes_definition(definition, event)
         )
 
     @defer.inlineCallbacks
     def test_filter_public_user_data_match(self):
-        user_filter = {
+        user_filter_json = {
             "public_user_data": {
                 "types": ["m.*"]
             }
@@ -362,7 +364,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -371,16 +373,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_public_user_data(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_public_user_data(events=events)
         self.assertEquals(events, results)
 
     @defer.inlineCallbacks
     def test_filter_public_user_data_no_match(self):
-        user_filter = {
+        user_filter_json = {
             "public_user_data": {
                 "types": ["m.*"]
             }
@@ -388,7 +391,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -397,16 +400,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_public_user_data(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_public_user_data(events=events)
         self.assertEquals([], results)
 
     @defer.inlineCallbacks
     def test_filter_room_state_match(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -416,7 +420,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -425,16 +429,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_room_state(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_room_state(events=events)
         self.assertEquals(events, results)
 
     @defer.inlineCallbacks
     def test_filter_room_state_no_match(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -444,7 +449,7 @@ class FilteringTestCase(unittest.TestCase):
         user = UserID.from_string("@" + user_localpart + ":test")
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
         event = MockEvent(
             sender="@foo:bar",
@@ -453,16 +458,17 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        results = yield self.filtering.filter_room_state(
-            events=events,
-            user=user,
-            filter_id=filter_id
+        user_filter = yield self.filtering.get_user_filter(
+            user_localpart=user_localpart,
+            filter_id=filter_id,
         )
+
+        results = user_filter.filter_room_state(events)
         self.assertEquals([], results)
 
     @defer.inlineCallbacks
     def test_add_filter(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -472,11 +478,11 @@ class FilteringTestCase(unittest.TestCase):
 
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
 
         self.assertEquals(filter_id, 0)
-        self.assertEquals(user_filter,
+        self.assertEquals(user_filter_json,
             (yield self.datastore.get_user_filter(
                 user_localpart=user_localpart,
                 filter_id=0,
@@ -485,7 +491,7 @@ class FilteringTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_filter(self):
-        user_filter = {
+        user_filter_json = {
             "room": {
                 "state": {
                     "types": ["m.*"]
@@ -495,7 +501,7 @@ class FilteringTestCase(unittest.TestCase):
 
         filter_id = yield self.datastore.add_user_filter(
             user_localpart=user_localpart,
-            user_filter=user_filter,
+            user_filter=user_filter_json,
         )
 
         filter = yield self.filtering.get_user_filter(
@@ -503,4 +509,4 @@ class FilteringTestCase(unittest.TestCase):
             filter_id=filter_id,
         )
 
-        self.assertEquals(filter, user_filter)
+        self.assertEquals(filter.filter_json, user_filter_json)