summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py252
1 files changed, 252 insertions, 0 deletions
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000000..13f6b31c9a
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,252 @@
+from synapse.http.server import HttpServer
+from synapse.api.errors import cs_error, CodeMessageException, StoreError
+from synapse.api.constants import Membership
+
+from synapse.api.events.room import (
+    RoomMemberEvent, MessageEvent
+)
+
+from twisted.internet import defer
+
+from collections import namedtuple
+from mock import patch, Mock
+import json
+import urlparse
+
+
+class MockHttpServer(HttpServer):
+
+    def __init__(self, prefix=""):
+        self.callbacks = []  # 3-tuple of method/pattern/function
+        self.prefix = prefix
+
+    def trigger_get(self, path):
+        return self.trigger("GET", path, None)
+
+    @patch('twisted.web.http.Request')
+    @defer.inlineCallbacks
+    def trigger(self, http_method, path, content, mock_request):
+        """ Fire an HTTP event.
+
+        Args:
+            http_method : The HTTP method
+            path : The HTTP path
+            content : The HTTP body
+            mock_request : Mocked request to pass to the event so it can get
+                           content.
+        Returns:
+            A tuple of (code, response)
+        Raises:
+            KeyError If no event is found which will handle the path.
+        """
+        path = self.prefix + path
+
+        # annoyingly we return a twisted http request which has chained calls
+        # to get at the http content, hence mock it here.
+        mock_content = Mock()
+        config = {'read.return_value': content}
+        mock_content.configure_mock(**config)
+        mock_request.content = mock_content
+
+        # return the right path if the event requires it
+        mock_request.path = path
+
+        # add in query params to the right place
+        try:
+            mock_request.args = urlparse.parse_qs(path.split('?')[1])
+            mock_request.path = path.split('?')[0]
+            path = mock_request.path
+        except:
+            pass
+
+        for (method, pattern, func) in self.callbacks:
+            if http_method != method:
+                continue
+
+            matcher = pattern.match(path)
+            if matcher:
+                try:
+                    (code, response) = yield func(
+                        mock_request,
+                        *matcher.groups()
+                    )
+                    defer.returnValue((code, response))
+                except CodeMessageException as e:
+                    defer.returnValue((e.code, cs_error(e.msg)))
+
+        raise KeyError("No event can handle %s" % path)
+
+    def register_path(self, method, path_pattern, callback):
+        self.callbacks.append((method, path_pattern, callback))
+
+
+class MemoryDataStore(object):
+
+    class RoomMember(namedtuple(
+        "RoomMember",
+        ["room_id", "user_id", "sender", "membership", "content"]
+    )):
+        def as_event(self, event_factory):
+            return event_factory.create_event(
+                etype=RoomMemberEvent.TYPE,
+                room_id=self.room_id,
+                target_user_id=self.user_id,
+                user_id=self.sender,
+                content=json.loads(self.content),
+            )
+
+    PathData = namedtuple("PathData",
+                          ["room_id", "path", "content"])
+
+    Message = namedtuple("Message",
+                         ["room_id", "msg_id", "user_id", "content"])
+
+    Room = namedtuple("Room",
+                      ["room_id", "is_public", "creator"])
+
+    def __init__(self):
+        self.tokens_to_users = {}
+        self.paths_to_content = {}
+        self.members = {}
+        self.messages = {}
+        self.rooms = {}
+        self.room_members = {}
+
+    def register(self, user_id, token, password_hash):
+        if user_id in self.tokens_to_users.values():
+            raise StoreError(400, "User in use.")
+        self.tokens_to_users[token] = user_id
+
+    def get_user_by_token(self, token):
+        try:
+            return self.tokens_to_users[token]
+        except:
+            raise StoreError(400, "User does not exist.")
+
+    def get_room(self, room_id):
+        try:
+            return self.rooms[room_id]
+        except:
+            return None
+
+    def store_room(self, room_id, room_creator_user_id, is_public):
+        if room_id in self.rooms:
+            raise StoreError(409, "Conflicting room!")
+
+        room = MemoryDataStore.Room(room_id=room_id, is_public=is_public,
+                    creator=room_creator_user_id)
+        self.rooms[room_id] = room
+        #self.store_room_member(user_id=room_creator_user_id, room_id=room_id,
+                               #membership=Membership.JOIN,
+                               #content={"membership": Membership.JOIN})
+
+    def get_message(self, user_id=None, room_id=None, msg_id=None):
+        try:
+            return self.messages[user_id + room_id + msg_id]
+        except:
+            return None
+
+    def store_message(self, user_id=None, room_id=None, msg_id=None,
+                      content=None):
+        msg = MemoryDataStore.Message(room_id=room_id, msg_id=msg_id,
+                    user_id=user_id, content=content)
+        self.messages[user_id + room_id + msg_id] = msg
+
+    def get_room_member(self, user_id=None, room_id=None):
+        try:
+            return self.members[user_id + room_id]
+        except:
+            return None
+
+    def get_room_members(self, room_id=None, membership=None):
+        try:
+            return self.room_members[room_id]
+        except:
+            return None
+
+    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
+        return [r for r in self.room_members
+                if user_id in self.room_members[r]]
+
+    def store_room_member(self, user_id=None, sender=None, room_id=None,
+                          membership=None, content=None):
+        member = MemoryDataStore.RoomMember(room_id=room_id, user_id=user_id,
+            sender=sender, membership=membership, content=json.dumps(content))
+        self.members[user_id + room_id] = member
+
+        # TODO should be latest state
+        if room_id not in self.room_members:
+            self.room_members[room_id] = []
+        self.room_members[room_id].append(member)
+
+    def get_room_data(self, room_id, etype, state_key=""):
+        path = "%s-%s-%s" % (room_id, etype, state_key)
+        try:
+            return self.paths_to_content[path]
+        except:
+            return None
+
+    def store_room_data(self, room_id, etype, state_key="", content=None):
+        path = "%s-%s-%s" % (room_id, etype, state_key)
+        data = MemoryDataStore.PathData(path=path, room_id=room_id,
+                    content=content)
+        self.paths_to_content[path] = data
+
+    def get_message_stream(self, user_id=None, from_key=None, to_key=None,
+                            room_id=None, limit=0, with_feedback=False):
+        return ([], from_key)  # TODO
+
+    def get_room_member_stream(self, user_id=None, from_key=None, to_key=None):
+        return ([], from_key)  # TODO
+
+    def get_feedback_stream(self, user_id=None, from_key=None, to_key=None,
+                            room_id=None, limit=0):
+        return ([], from_key)  # TODO
+
+    def get_room_data_stream(self, user_id=None, from_key=None, to_key=None,
+                            room_id=None, limit=0):
+        return ([], from_key)  # TODO
+
+    def to_events(self, data_store_list):
+        return data_store_list  # TODO
+
+    def get_max_message_id(self):
+        return 0  # TODO
+
+    def get_max_feedback_id(self):
+        return 0  # TODO
+
+    def get_max_room_member_id(self):
+        return 0  # TODO
+
+    def get_max_room_data_id(self):
+        return 0  # TODO
+
+    def get_joined_hosts_for_room(self, room_id):
+        return defer.succeed([])
+
+    def persist_event(self, event):
+        if event.type == MessageEvent.TYPE:
+            return self.store_message(
+                user_id=event.user_id,
+                room_id=event.room_id,
+                msg_id=event.msg_id,
+                content=json.dumps(event.content)
+            )
+        elif event.type == RoomMemberEvent.TYPE:
+            return self.store_room_member(
+                user_id=event.target_user_id,
+                room_id=event.room_id,
+                content=event.content,
+                membership=event.content["membership"]
+            )
+        else:
+            raise NotImplementedError(
+                "Don't know how to persist type=%s" % event.type
+            )
+
+    def set_presence_state(self, user_localpart, state):
+        return defer.succeed({"state": 0})
+
+    def get_presence_list(self, user_localpart, accepted):
+        return []