diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index e16c0e559f..b7e5d3222f 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -25,127 +25,25 @@ class Filtering(object):
self.store = hs.get_datastore()
def get_user_filter(self, user_localpart, filter_id):
- return self.store.get_user_filter(user_localpart, filter_id)
+ result = self.store.get_user_filter(user_localpart, filter_id)
+ result.addCallback(Filter)
+ return result
def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
- def filter_public_user_data(self, events, user, filter_id):
- return self._filter_on_key(
- events, user, filter_id, ["public_user_data"]
- )
-
- def filter_private_user_data(self, events, user, filter_id):
- return self._filter_on_key(
- events, user, filter_id, ["private_user_data"]
- )
-
- def filter_room_state(self, events, user, filter_id):
- return self._filter_on_key(
- events, user, filter_id, ["room", "state"]
- )
-
- def filter_room_events(self, events, user, filter_id):
- return self._filter_on_key(
- events, user, filter_id, ["room", "events"]
- )
-
- def filter_room_ephemeral(self, events, user, filter_id):
- return self._filter_on_key(
- events, user, filter_id, ["room", "ephemeral"]
- )
-
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
- @defer.inlineCallbacks
- def _filter_on_key(self, events, user, filter_id, keys):
- filter_json = yield self.get_user_filter(user.localpart, filter_id)
- if not filter_json:
- defer.returnValue(events)
-
- try:
- # extract the right definition from the filter
- definition = filter_json
- for key in keys:
- definition = definition[key]
- defer.returnValue(self._filter_with_definition(events, definition))
- except KeyError:
- # return all events if definition isn't specified.
- defer.returnValue(events)
-
- def _filter_with_definition(self, events, definition):
- return [e for e in events if self._passes_definition(definition, e)]
-
- def _passes_definition(self, definition, event):
- """Check if the event passes through the given definition.
-
- Args:
- definition(dict): The definition to check against.
- event(Event): The event to check.
- Returns:
- True if the event passes through the filter.
- """
- # Algorithm notes:
- # For each key in the definition, check the event meets the criteria:
- # * For types: Literal match or prefix match (if ends with wildcard)
- # * For senders/rooms: Literal match only
- # * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
- # and 'not_types' then it is treated as only being in 'not_types')
-
- # room checks
- if hasattr(event, "room_id"):
- room_id = event.room_id
- allow_rooms = definition.get("rooms", None)
- reject_rooms = definition.get("not_rooms", None)
- if reject_rooms and room_id in reject_rooms:
- return False
- if allow_rooms and room_id not in allow_rooms:
- return False
-
- # sender checks
- if hasattr(event, "sender"):
- # Should we be including event.state_key for some event types?
- sender = event.sender
- allow_senders = definition.get("senders", None)
- reject_senders = definition.get("not_senders", None)
- if reject_senders and sender in reject_senders:
- return False
- if allow_senders and sender not in allow_senders:
- return False
-
- # type checks
- if "not_types" in definition:
- for def_type in definition["not_types"]:
- if self._event_matches_type(event, def_type):
- return False
- if "types" in definition:
- included = False
- for def_type in definition["types"]:
- if self._event_matches_type(event, def_type):
- included = True
- break
- if not included:
- return False
-
- return True
-
- def _event_matches_type(self, event, def_type):
- if def_type.endswith("*"):
- type_prefix = def_type[:-1]
- return event.type.startswith(type_prefix)
- else:
- return event.type == def_type
-
- def _check_valid_filter(self, user_filter):
+ def _check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
- user_filter(dict): The filter
+ user_filter_json(dict): The filter
Raises:
SynapseError: If the filter is not valid.
"""
@@ -162,13 +60,13 @@ class Filtering(object):
]
for key in top_level_definitions:
- if key in user_filter:
- self._check_definition(user_filter[key])
+ if key in user_filter_json:
+ self._check_definition(user_filter_json[key])
- if "room" in user_filter:
+ if "room" in user_filter_json:
for key in room_level_definitions:
- if key in user_filter["room"]:
- self._check_definition(user_filter["room"][key])
+ if key in user_filter_json["room"]:
+ self._check_definition(user_filter_json["room"][key])
def _check_definition(self, definition):
"""Check if the provided definition is valid.
@@ -237,3 +135,101 @@ class Filtering(object):
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
+
+
+class Filter(object):
+ def __init__(self, filter_json):
+ self.filter_json = filter_json
+
+ def filter_public_user_data(self, events):
+ return self._filter_on_key(events, ["public_user_data"])
+
+ def filter_private_user_data(self, events):
+ return self._filter_on_key(events, ["private_user_data"])
+
+ def filter_room_state(self, events):
+ return self._filter_on_key(events, ["room", "state"])
+
+ def filter_room_events(self, events):
+ return self._filter_on_key(events, ["room", "events"])
+
+ def filter_room_ephemeral(self, events):
+ return self._filter_on_key(events, ["room", "ephemeral"])
+
+ def _filter_on_key(self, events, keys):
+ filter_json = self.filter_json
+ if not filter_json:
+ return events
+
+ try:
+ # extract the right definition from the filter
+ definition = filter_json
+ for key in keys:
+ definition = definition[key]
+ return self._filter_with_definition(events, definition)
+ except KeyError:
+ # return all events if definition isn't specified.
+ return events
+
+ def _filter_with_definition(self, events, definition):
+ return [e for e in events if self._passes_definition(definition, e)]
+
+ def _passes_definition(self, definition, event):
+ """Check if the event passes through the given definition.
+
+ Args:
+ definition(dict): The definition to check against.
+ event(Event): The event to check.
+ Returns:
+ True if the event passes through the filter.
+ """
+ # Algorithm notes:
+ # For each key in the definition, check the event meets the criteria:
+ # * For types: Literal match or prefix match (if ends with wildcard)
+ # * For senders/rooms: Literal match only
+ # * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
+ # and 'not_types' then it is treated as only being in 'not_types')
+
+ # room checks
+ if hasattr(event, "room_id"):
+ room_id = event.room_id
+ allow_rooms = definition.get("rooms", None)
+ reject_rooms = definition.get("not_rooms", None)
+ if reject_rooms and room_id in reject_rooms:
+ return False
+ if allow_rooms and room_id not in allow_rooms:
+ return False
+
+ # sender checks
+ if hasattr(event, "sender"):
+ # Should we be including event.state_key for some event types?
+ sender = event.sender
+ allow_senders = definition.get("senders", None)
+ reject_senders = definition.get("not_senders", None)
+ if reject_senders and sender in reject_senders:
+ return False
+ if allow_senders and sender not in allow_senders:
+ return False
+
+ # type checks
+ if "not_types" in definition:
+ for def_type in definition["not_types"]:
+ if self._event_matches_type(event, def_type):
+ return False
+ if "types" in definition:
+ included = False
+ for def_type in definition["types"]:
+ if self._event_matches_type(event, def_type):
+ included = True
+ break
+ if not included:
+ return False
+
+ return True
+
+ def _event_matches_type(self, event, def_type):
+ if def_type.endswith("*"):
+ type_prefix = def_type[:-1]
+ return event.type.startswith(type_prefix)
+ else:
+ return event.type == def_type
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index cee06ccaca..6ddc495d23 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet):
filter_id=filter_id,
)
- defer.returnValue((200, filter))
+ defer.returnValue((200, filter.filter_json))
except KeyError:
raise SynapseError(400, "No such filter")
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index aa93616a9f..babf4c37f1 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -25,6 +25,7 @@ from tests.utils import (
from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.api.filtering import Filter
user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")
@@ -53,6 +54,7 @@ class FilteringTestCase(unittest.TestCase):
)
self.filtering = hs.get_filtering()
+ self.filter = Filter({})
self.datastore = hs.get_datastore()
@@ -66,7 +68,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_types_works_with_wildcards(self):
@@ -79,7 +81,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_types_works_with_unknowns(self):
@@ -92,7 +94,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_types_works_with_literals(self):
@@ -105,7 +107,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_types_works_with_wildcards(self):
@@ -118,7 +120,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_types_works_with_unknowns(self):
@@ -131,7 +133,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_types_takes_priority_over_types(self):
@@ -145,7 +147,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_senders_works_with_literals(self):
@@ -158,7 +160,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_senders_works_with_unknowns(self):
@@ -171,7 +173,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_senders_works_with_literals(self):
@@ -184,7 +186,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_senders_works_with_unknowns(self):
@@ -197,7 +199,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_senders_takes_priority_over_senders(self):
@@ -211,7 +213,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_rooms_works_with_literals(self):
@@ -224,7 +226,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_rooms_works_with_unknowns(self):
@@ -237,7 +239,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_rooms_works_with_literals(self):
@@ -250,7 +252,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_rooms_works_with_unknowns(self):
@@ -263,7 +265,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown"
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_not_rooms_takes_priority_over_rooms(self):
@@ -277,7 +279,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown"
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_combined_event(self):
@@ -295,7 +297,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertTrue(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_combined_event_bad_sender(self):
@@ -313,7 +315,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_combined_event_bad_room(self):
@@ -331,7 +333,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!piggyshouse:muppets" # nope
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
def test_definition_combined_event_bad_type(self):
@@ -349,12 +351,12 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup
)
self.assertFalse(
- self.filtering._passes_definition(definition, event)
+ self.filter._passes_definition(definition, event)
)
@defer.inlineCallbacks
def test_filter_public_user_data_match(self):
- user_filter = {
+ user_filter_json = {
"public_user_data": {
"types": ["m.*"]
}
@@ -362,7 +364,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
- user_filter=user_filter,
+ user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
@@ -371,16 +373,17 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- results = yield self.filtering.filter_public_user_data(
- events=events,
- user=user,
- filter_id=filter_id
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
)
+
+ results = user_filter.filter_public_user_data(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_public_user_data_no_match(self):
- user_filter = {
+ user_filter_json = {
"public_user_data": {
"types": ["m.*"]
}
@@ -388,7 +391,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
- user_filter=user_filter,
+ user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
@@ -397,16 +400,17 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- results = yield self.filtering.filter_public_user_data(
- events=events,
- user=user,
- filter_id=filter_id
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
)
+
+ results = user_filter.filter_public_user_data(events=events)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_filter_room_state_match(self):
- user_filter = {
+ user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
@@ -416,7 +420,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
- user_filter=user_filter,
+ user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
@@ -425,16 +429,17 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- results = yield self.filtering.filter_room_state(
- events=events,
- user=user,
- filter_id=filter_id
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
)
+
+ results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_room_state_no_match(self):
- user_filter = {
+ user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
@@ -444,7 +449,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
- user_filter=user_filter,
+ user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
@@ -453,16 +458,17 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- results = yield self.filtering.filter_room_state(
- events=events,
- user=user,
- filter_id=filter_id
+ user_filter = yield self.filtering.get_user_filter(
+ user_localpart=user_localpart,
+ filter_id=filter_id,
)
+
+ results = user_filter.filter_room_state(events)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_add_filter(self):
- user_filter = {
+ user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
@@ -472,11 +478,11 @@ class FilteringTestCase(unittest.TestCase):
filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart,
- user_filter=user_filter,
+ user_filter=user_filter_json,
)
self.assertEquals(filter_id, 0)
- self.assertEquals(user_filter,
+ self.assertEquals(user_filter_json,
(yield self.datastore.get_user_filter(
user_localpart=user_localpart,
filter_id=0,
@@ -485,7 +491,7 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_filter(self):
- user_filter = {
+ user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
@@ -495,7 +501,7 @@ class FilteringTestCase(unittest.TestCase):
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
- user_filter=user_filter,
+ user_filter=user_filter_json,
)
filter = yield self.filtering.get_user_filter(
@@ -503,4 +509,4 @@ class FilteringTestCase(unittest.TestCase):
filter_id=filter_id,
)
- self.assertEquals(filter, user_filter)
+ self.assertEquals(filter.filter_json, user_filter_json)
|