diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index e79e91e7eb..cd7a465e97 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -24,7 +24,7 @@ class Filtering(object):
def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id)
- result.addCallback(Filter)
+ result.addCallback(FilterCollection)
return result
def add_user_filter(self, user_localpart, user_filter):
@@ -131,125 +131,82 @@ class Filtering(object):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
-class Filter(object):
+class FilterCollection(object):
def __init__(self, filter_json):
self.filter_json = filter_json
+ self.room_timeline_filter = Filter(
+ self.filter_json.get("room", {}).get("timeline", {})
+ )
+
+ self.room_state_filter = Filter(
+ self.filter_json.get("room", {}).get("state", {})
+ )
+
+ self.room_ephemeral_filter = Filter(
+ self.filter_json.get("room", {}).get("ephemeral", {})
+ )
+
+ self.presence_filter = Filter(
+ self.filter_json.get("presence", {})
+ )
+
def timeline_limit(self):
- return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10)
+ return self.room_timeline_filter.limit()
def presence_limit(self):
- return self.filter_json.get("presence", {}).get("limit", 10)
+ return self.presence_filter.limit()
def ephemeral_limit(self):
- return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10)
+ return self.room_ephemeral_filter.limit()
def filter_presence(self, events):
- return self._filter_on_key(events, ["presence"])
+ return self.presence_filter.filter(events)
def filter_room_state(self, events):
- return self._filter_on_key(events, ["room", "state"])
+ return self.room_state_filter.filter(events)
def filter_room_timeline(self, events):
- return self._filter_on_key(events, ["room", "timeline"])
+ return self.room_timeline_filter.filter(events)
def filter_room_ephemeral(self, events):
- return self._filter_on_key(events, ["room", "ephemeral"])
-
- def _filter_on_key(self, events, keys):
- filter_json = self.filter_json
- if not filter_json:
- return events
-
- try:
- # extract the right definition from the filter
- definition = filter_json
- for key in keys:
- definition = definition[key]
- return self._filter_with_definition(events, definition)
- except KeyError:
- # return all events if definition isn't specified.
- return events
-
- def _filter_with_definition(self, events, definition):
- return [e for e in events if self._passes_definition(definition, e)]
-
- def _passes_definition(self, definition, event):
- """Check if the event passes the filter definition
- Args:
- definition(dict): The filter definition to check against
- event(dict or Event): The event to check
- Returns:
- True if the event passes the filter in the definition
- """
- if type(event) is dict:
- room_id = event.get("room_id")
- sender = event.get("sender")
- event_type = event["type"]
- else:
- room_id = getattr(event, "room_id", None)
- sender = getattr(event, "sender", None)
- event_type = event.type
- return self._event_passes_definition(
- definition, room_id, sender, event_type
- )
+ return self.room_ephemeral_filter.filter(events)
- def _event_passes_definition(self, definition, room_id, sender,
- event_type):
- """Check if the event passes through the given definition.
- Args:
- definition(dict): The definition to check against.
- room_id(str): The id of the room this event is in or None.
- sender(str): The sender of the event
- event_type(str): The type of the event.
- Returns:
- True if the event passes through the filter.
- """
- # Algorithm notes:
- # For each key in the definition, check the event meets the criteria:
- # * For types: Literal match or prefix match (if ends with wildcard)
- # * For senders/rooms: Literal match only
- # * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
- # and 'not_types' then it is treated as only being in 'not_types')
-
- # room checks
- if room_id is not None:
- allow_rooms = definition.get("rooms", None)
- reject_rooms = definition.get("not_rooms", None)
- if reject_rooms and room_id in reject_rooms:
- return False
- if allow_rooms and room_id not in allow_rooms:
- return False
+class Filter(object):
+ def __init__(self, filter_json):
+ self.filter_json = filter_json
- # sender checks
- if sender is not None:
- allow_senders = definition.get("senders", None)
- reject_senders = definition.get("not_senders", None)
- if reject_senders and sender in reject_senders:
- return False
- if allow_senders and sender not in allow_senders:
+ def check(self, event):
+ literal_keys = {
+ "rooms": lambda v: event.room_id == v,
+ "senders": lambda v: event.sender == v,
+ "types": lambda v: _matches_wildcard(event.type, v)
+ }
+
+ for name, match_func in literal_keys.items():
+ not_name = "not_%s" % (name,)
+ disallowed_values = self.filter_json.get(not_name, [])
+ if any(map(match_func, disallowed_values)):
return False
- # type checks
- if "not_types" in definition:
- for def_type in definition["not_types"]:
- if self._event_matches_type(event_type, def_type):
+ allowed_values = self.filter_json.get(name, None)
+ if allowed_values is not None:
+ if not any(map(match_func, allowed_values)):
return False
- if "types" in definition:
- included = False
- for def_type in definition["types"]:
- if self._event_matches_type(event_type, def_type):
- included = True
- break
- if not included:
- return False
return True
- def _event_matches_type(self, event_type, def_type):
- if def_type.endswith("*"):
- type_prefix = def_type[:-1]
- return event_type.startswith(type_prefix)
- else:
- return event_type == def_type
+ def filter(self, events):
+ return filter(self.check, events)
+
+ def limit(self):
+ return self.filter_json.get("limit", 10)
+
+
+def _matches_wildcard(actual_value, filter_value):
+ if filter_value.endswith("*"):
+ type_prefix = filter_value[:-1]
+ return actual_value.startswith(type_prefix)
+ else:
+ return actual_value == filter_value
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index fffecb24f5..5e27a859f9 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -23,7 +23,7 @@ from synapse.types import StreamToken
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id,
)
-from synapse.api.filtering import Filter
+from synapse.api.filtering import FilterCollection
from ._base import client_v2_pattern
import copy
@@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet):
user.localpart, filter_id
)
except:
- filter = Filter({})
+ filter = FilterCollection({})
sync_config = SyncConfig(
user=user,
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6942cdac51..9f9af2d783 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -23,10 +23,17 @@ from tests.utils import (
)
from synapse.types import UserID
-from synapse.api.filtering import Filter
+from synapse.api.filtering import FilterCollection, Filter
user_localpart = "test_user"
-MockEvent = namedtuple("MockEvent", "sender type room_id")
+# MockEvent = namedtuple("MockEvent", "sender type room_id")
+
+
+def MockEvent(**kwargs):
+ ev = NonCallableMock(spec_set=kwargs.keys())
+ ev.configure_mock(**kwargs)
+ return ev
+
class FilteringTestCase(unittest.TestCase):
@@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase):
)
self.filtering = hs.get_filtering()
- self.filter = Filter({})
self.datastore = hs.get_datastore()
@@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase):
type="m.room.message",
room_id="!foo:bar"
)
+
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_types_works_with_wildcards(self):
@@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_types_works_with_unknowns(self):
@@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_works_with_literals(self):
@@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_works_with_wildcards(self):
@@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_works_with_unknowns(self):
@@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_types_takes_priority_over_types(self):
@@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_senders_works_with_literals(self):
@@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_senders_works_with_unknowns(self):
@@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_senders_works_with_literals(self):
@@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_senders_works_with_unknowns(self):
@@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_senders_takes_priority_over_senders(self):
@@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_rooms_works_with_literals(self):
@@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_rooms_works_with_unknowns(self):
@@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_rooms_works_with_literals(self):
@@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_rooms_works_with_unknowns(self):
@@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown"
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event(self):
@@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertTrue(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event_bad_sender(self):
@@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event_bad_room(self):
@@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!piggyshouse:muppets" # nope
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
def test_definition_combined_event_bad_type(self):
@@ -341,7 +348,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertFalse(
- self.filter._passes_definition(definition, event)
+ Filter(definition).check(event)
)
@defer.inlineCallbacks
@@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(
sender="@foo:bar",
type="m.profile",
- room_id="!foo:bar"
)
events = [event]
@@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(
sender="@foo:bar",
type="custom.avatar.3d.crazy",
- room_id="!foo:bar"
)
events = [event]
|