diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
new file mode 100644
index 0000000000..922c40004c
--- /dev/null
+++ b/synapse/api/filtering.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# TODO(paul)
+_filters_for_user = {}
+
+
+class Filtering(object):
+
+ def __init__(self, hs):
+ super(Filtering, self).__init__()
+ self.hs = hs
+
+ def get_user_filter(self, user_localpart, filter_id):
+ filters = _filters_for_user.get(user_localpart, None)
+
+ if not filters or filter_id >= len(filters):
+ raise KeyError()
+
+ return filters[filter_id]
+
+ def add_user_filter(self, user_localpart, definition):
+ filters = _filters_for_user.setdefault(user_localpart, [])
+
+ filter_id = len(filters)
+ filters.append(definition)
+
+ return filter_id
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index a9a180ec04..585c8e02e8 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,10 +28,6 @@ import logging
logger = logging.getLogger(__name__)
-# TODO(paul)
-_filters_for_user = {}
-
-
class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
@@ -39,6 +35,7 @@ class GetFilterRestServlet(RestServlet):
super(GetFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
+ self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
@@ -56,13 +53,14 @@ class GetFilterRestServlet(RestServlet):
except:
raise SynapseError(400, "Invalid filter_id")
- filters = _filters_for_user.get(target_user.localpart, None)
-
- if not filters or filter_id >= len(filters):
+ try:
+ defer.returnValue((200, self.filtering.get_user_filter(
+ user_localpart=target_user.localpart,
+ filter_id=filter_id,
+ )))
+ except KeyError:
raise SynapseError(400, "No such filter")
- defer.returnValue((200, filters[filter_id]))
-
class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
@@ -71,6 +69,7 @@ class CreateFilterRestServlet(RestServlet):
super(CreateFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
+ self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_POST(self, request, user_id):
@@ -90,10 +89,10 @@ class CreateFilterRestServlet(RestServlet):
except:
raise SynapseError(400, "Invalid filter definition")
- filters = _filters_for_user.setdefault(target_user.localpart, [])
-
- filter_id = len(filters)
- filters.append(content)
+ filter_id = self.filtering.add_user_filter(
+ user_localpart=target_user.localpart,
+ definition=content,
+ )
defer.returnValue((200, {"filter_id": str(filter_id)}))
diff --git a/synapse/server.py b/synapse/server.py
index f09d5d581e..9b42079e05 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -32,6 +32,7 @@ from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering
class BaseHomeServer(object):
@@ -79,6 +80,7 @@ class BaseHomeServer(object):
'ratelimiter',
'keyring',
'event_builder_factory',
+ 'filtering',
]
def __init__(self, hostname, **kwargs):
@@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer):
clock=self.get_clock(),
hostname=self.hostname,
)
+
+ def build_filtering(self):
+ return Filtering(self)
|