summary refs log blame commit diff
path: root/tests/utils.py
blob: 60fd6085ac6f892a787c4f994bda7da01101f5a4 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
                       
                               











                                                                          

                                                                         
                                            



                                     
                                           
                                                    




                                  








                                                                          
 
                                                































                                                                             

                                         


                                                                






























                                                                     













                                   







                                 


                                 
 










                                                       
                              


                                           


                                  
 
                         
                       

                               
 

                                                                               




                                                                               





                                                      



                                                    











                                                                   



                                        
                                  
 
                                                         
 






                                                                   
                                                                               
                

                                                                      
         
 
                                                                              

                                                                        


                                                 
                                              
                                  










                                                                        

                                                    
             


                                             




                                                          

                                     
 











                                                   
                                      
                                       
 
                                                                              
                                  
 
 
                               
                                    










                                                                              
                       
                                        
                                         








                                                                             
                                                               

                                      






                                         


                                                                  



















                                                                           










                                                                               
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.http.server import HttpServer
from synapse.api.errors import cs_error, CodeMessageException, StoreError
from synapse.api.constants import Membership
from synapse.storage import prepare_database

from synapse.api.events.room import (
    RoomMemberEvent, MessageEvent
)

from twisted.internet import defer, reactor
from twisted.enterprise.adbapi import ConnectionPool

from collections import namedtuple
from mock import patch, Mock
import json
import urlparse

from inspect import getcallargs


def get_mock_call_args(pattern_func, mock_func):
    """ Return the arguments the mock function was called with interpreted
    by the pattern functions argument list.

    Args:
        pattern_func : The function whose signature is used to name the
                       arguments.
        mock_func : The mock whose recorded ``call_args`` is inspected.
    Returns:
        A dict mapping pattern_func's parameter names to the values the mock
        was called with, as computed by inspect.getcallargs.
    Raises:
        AssertionError If the mock has no recorded call (``call_args`` is
                       None).
    """
    recorded_call = mock_func.call_args
    if recorded_call is None:
        # Without this check the tuple-unpack below fails with an opaque
        # "NoneType is not iterable" TypeError.
        raise AssertionError(
            "%r was never called, so it has no call arguments" % (mock_func,)
        )

    invoked_args, invoked_kargs = recorded_call
    return getcallargs(pattern_func, *invoked_args, **invoked_kargs)


# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):
    """A mock HTTP resource for unit tests.

    Callbacks are registered against an HTTP method and a path pattern with
    register_path(); trigger() then builds a faked request and dispatches it
    to the first registered callback whose method and pattern match.
    """

    def __init__(self, prefix=""):
        self.callbacks = []  # 3-tuple of method/pattern/function
        # Prepended to every path passed to trigger().
        self.prefix = prefix

    def trigger_get(self, path):
        """ Fire a GET request with no body at the given path.

        Returns:
            Whatever trigger() returns for that request.
        """
        return self.trigger("GET", path, None)

    @patch('twisted.web.http.Request')
    @defer.inlineCallbacks
    def trigger(self, http_method, path, content, mock_request):
        """ Fire an HTTP event.

        Args:
            http_method : The HTTP method
            path : The HTTP path
            content : The HTTP body
            mock_request : Mocked request to pass to the event so it can get
                           content. Supplied by the @patch decorator; callers
                           do not pass it.
        Returns:
            A Deferred (this is an inlineCallbacks function) which fires with
            a tuple of (code, response)
        Raises:
            KeyError If no event is found which will handle the path.
        """
        path = self.prefix + path

        # annoyingly we return a twisted http request which has chained calls
        # to get at the http content, hence mock it here.
        mock_content = Mock()
        config = {'read.return_value': content}
        mock_content.configure_mock(**config)
        mock_request.content = mock_content

        mock_request.method = http_method
        mock_request.uri = path

        mock_request.requestHeaders.getRawHeaders.return_value = [
            "X-Matrix origin=test,key=,sig="
        ]

        # return the right path if the event requires it
        mock_request.path = path

        # add in query params to the right place. Only done when the path
        # actually carries a query string; otherwise request.args and
        # request.path are left exactly as set above.
        path_parts = path.split('?')
        if len(path_parts) > 1:
            mock_request.args = urlparse.parse_qs(path_parts[1])
            mock_request.path = path_parts[0]
            path = mock_request.path

        for (method, pattern, func) in self.callbacks:
            if http_method != method:
                continue

            matcher = pattern.match(path)
            if matcher:
                try:
                    (code, response) = yield func(
                        mock_request,
                        *matcher.groups()
                    )
                    defer.returnValue((code, response))
                except CodeMessageException as e:
                    # Handler errors are reported as a normal
                    # (code, error body) response instead of a failure.
                    defer.returnValue((e.code, cs_error(e.msg)))

        raise KeyError("No event can handle %s" % path)

    def register_path(self, method, path_pattern, callback):
        """ Register a callback for requests matching method and pattern.

        Args:
            method : The HTTP method to match exactly.
            path_pattern : Object whose match(path) is called by trigger();
                           the groups of the resulting match are passed to
                           the callback (presumably a compiled regex).
            callback : Called as callback(request, *groups).
        """
        self.callbacks.append((method, path_pattern, callback))


class MockKey(object):
    """Stand-in signing key that always uses one fixed signature."""

    alg = "mock_alg"
    version = "mock_version"

    # The only signature this key ever produces or accepts.
    _FIXED_SIGNATURE = b"\x9a\x87$"

    def sign(self, message):
        """Return the fixed signature; ``message`` is ignored."""
        return self._FIXED_SIGNATURE

    def verify(self, message, sig):
        """Assert that ``sig`` is the fixed signature; ``message`` is
        ignored.
        """
        assert sig == self._FIXED_SIGNATURE

    @property
    def verify_key(self):
        """The key object acts as its own verify key."""
        return self


class MockClock(object):
    """Manually driven clock for unit tests."""

    # Current mock time in seconds. This class-level value is the starting
    # point; advance_time() rebinds it on the instance.
    now = 1000

    def time(self):
        """Return the current mock time."""
        return self.now

    def time_msec(self):
        """Return the current mock time multiplied by 1000."""
        seconds = self.time()
        return seconds * 1000

    # For unit testing
    def advance_time(self, secs):
        """Move the mock time forward by ``secs``."""
        self.now = self.now + secs


class SQLiteMemoryDbPool(ConnectionPool, object):
    """ConnectionPool over an in-memory sqlite3 database, limited to exactly
    one connection.
    """

    def __init__(self):
        # cp_min == cp_max == 1 pins the pool to a single connection.
        # NOTE(review): presumably needed because every sqlite3 ":memory:"
        # connection gets its own private database - confirm.
        pool_size = {"cp_min": 1, "cp_max": 1}
        super(SQLiteMemoryDbPool, self).__init__(
            "sqlite3", ":memory:", **pool_size
        )

    def prepare(self):
        """Run prepare_database via runWithConnection and return the
        result of that call.
        """
        prepared = self.runWithConnection(prepare_database)
        return prepared


class MemoryDataStore(object):
    """In-memory stand-in for the data store, for use in unit tests.

    All state lives in plain dicts and lists on the instance. Some methods
    return plain values while others return already-fired Deferreds (via
    defer.succeed); each method's docstring says which.
    """

    Room = namedtuple(
        "Room",
        ["room_id", "is_public", "creator"]
    )

    def __init__(self):
        self.tokens_to_users = {}  # token -> user_id
        self.paths_to_content = {}

        self.members = {}  # room_id -> {user_id -> member event}
        self.rooms = {}  # room_id -> Room

        self.current_state = {}  # (room_id, type, state_key) -> event
        self.events = []  # every persisted event, in persist order

    class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
        def fill_out_prev_events(self, event):
            """No-op in this in-memory store."""
            pass

    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
        """Return a Snapshot holding the user's member event for the room.

        state_type and state_key are accepted but unused.
        """
        return self.Snapshot(
            room_id, user_id, self.get_room_member(user_id, room_id)
        )

    def register(self, user_id, token, password_hash):
        """Associate token with user_id; password_hash is ignored.

        Raises:
            StoreError (400) if user_id is already registered.
        """
        if user_id in self.tokens_to_users.values():
            raise StoreError(400, "User in use.")
        self.tokens_to_users[token] = user_id

    def get_user_by_token(self, token):
        """Return a dict describing the user registered with token.

        Raises:
            StoreError (400) if the token is unknown.
        """
        try:
            user_id = self.tokens_to_users[token]
        except KeyError:
            raise StoreError(400, "User does not exist.")

        return {
            "name": user_id,
            "admin": 0,
            "device_id": None,
        }

    def get_room(self, room_id):
        """Return the stored Room for room_id, or None if there is none."""
        return self.rooms.get(room_id)

    def store_room(self, room_id, room_creator_user_id, is_public):
        """Store a new Room.

        Raises:
            StoreError (409) if room_id is already stored.
        """
        if room_id in self.rooms:
            raise StoreError(409, "Conflicting room!")

        room = MemoryDataStore.Room(
            room_id=room_id,
            is_public=is_public,
            creator=room_creator_user_id
        )
        self.rooms[room_id] = room

    def get_room_member(self, user_id, room_id):
        """Return the member event for user_id in room_id, or None."""
        return self.members.get(room_id, {}).get(user_id)

    def get_room_members(self, room_id, membership=None):
        """Return a list of the member events stored for room_id.

        If membership is given (and truthy), only events whose ``membership``
        equals it are returned.
        """
        room_members = self.members.get(room_id, {})
        if membership:
            return [
                member for member in room_members.values()
                if member.membership == membership
            ]
        # Always hand back a real list, matching the filtered branch.
        return list(room_members.values())

    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
        """Return the user's member events, one per room, whose
        ``membership`` is in membership_list.
        """
        return [
            room_members[user_id] for room_members in self.members.values()
            if user_id in room_members and
            room_members[user_id].membership in membership_list
        ]

    def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
                               room_id=None, limit=0, with_feedback=False):
        """Stub: always returns an empty event list plus from_key."""
        return ([], from_key)  # TODO

    def get_joined_hosts_for_room(self, room_id):
        """Stub: returns a fired Deferred holding an empty list."""
        return defer.succeed([])

    def persist_event(self, event):
        """Record event.

        Member events update ``members``; any event with a ``state_key``
        attribute replaces the matching entry in ``current_state``; every
        event is appended to ``events``.
        """
        if event.type == RoomMemberEvent.TYPE:
            self.members.setdefault(event.room_id, {})[event.state_key] = event

        if hasattr(event, "state_key"):
            key = (event.room_id, event.type, event.state_key)
            self.current_state[key] = event

        self.events.append(event)

    def get_current_state(self, room_id, event_type=None, state_key=""):
        """Return current state events for room_id.

        With event_type: a one-element list holding the event stored under
        (room_id, event_type, state_key), or None if there is none.
        Without event_type: a list of every current state event of the room.
        """
        if event_type:
            event = self.current_state.get((room_id, event_type, state_key))
            return [event] if event else None

        # current_state is keyed by (room_id, type, state_key); return the
        # stored events rather than the key tuples.
        return [
            event for key, event in self.current_state.items()
            if key[0] == room_id
        ]

    def set_presence_state(self, user_localpart, state):
        """Stub: returns a fired Deferred holding {"state": 0}."""
        return defer.succeed({"state": 0})

    def get_presence_list(self, user_localpart, accepted):
        """Stub: always returns an empty list."""
        return []

    def get_room_events_max_id(self):
        """Stub: always returns 0."""
        return 0  # TODO (erikj)

    def get_send_event_level(self, room_id):
        """Stub: returns a fired Deferred holding 0."""
        return defer.succeed(0)

    def get_power_level(self, room_id, user_id):
        """Stub: returns a fired Deferred holding 0."""
        return defer.succeed(0)

    def get_add_state_level(self, room_id):
        """Stub: returns a fired Deferred holding 0."""
        return defer.succeed(0)

    def get_room_join_rule(self, room_id):
        """Stub: returns a fired Deferred holding "invite"."""
        # TODO (erikj): This should be configurable
        return defer.succeed("invite")

    def get_ops_levels(self, room_id):
        """Stub: returns a fired Deferred holding the triple (5, 5, 5)."""
        return defer.succeed((5, 5, 5))

    def insert_client_ip(self, user, device_id, access_token, ip, user_agent):
        """Stub: records nothing; returns a fired Deferred holding None."""
        return defer.succeed(None)


def _format_call(args, kwargs):
    return ", ".join(
        ["%r" % (a) for a in args] +
        ["%s=%r" % (k, v) for k, v in kwargs.items()]
    )


class DeferredMockCallable(object):
    """A callable instance that stores a set of pending call expectations and
    return values for them. It allows a unit test to assert that the given set
    of function calls are eventually made, by awaiting on them to be called.
    """

    def __init__(self):
        # List of (call, result, Deferred) triples; the Deferred fires when
        # the expected call is made.
        self.expectations = []
        # Every (args, kwargs) this instance was invoked with since the last
        # await_calls() / assert_had_no_calls() reset.
        self.calls = []

    def __call__(self, *args, **kwargs):
        """Record the call and return the result of the first expectation
        whose arguments match.

        Raises:
            ValueError If no expectations have been registered at all.
            AssertionError If the call matches no expectation; the same
                           error is also delivered to every expectation's
                           Deferred that has not fired yet.
        """
        self.calls.append((args, kwargs))

        if not self.expectations:
            raise ValueError("%r has no pending calls to handle call(%s)" % (
                self, _format_call(args, kwargs))
            )

        for (call, result, d) in self.expectations:
            if args == call[1] and kwargs == call[2]:
                d.callback(None)
                return result

        failure = AssertionError("Was not expecting call(%s)" %
            _format_call(args, kwargs)
        )

        for _, _, d in self.expectations:
            try:
                d.errback(failure)
            except defer.AlreadyCalledError:
                # This expectation has already fired; nothing left to fail.
                pass

        raise failure

    def expect_call_and_return(self, call, result):
        """Register an expected call and the value to return for it.

        Args:
            call : Indexable description of the call: call[1] is compared
                   with the positional args and call[2] with the keyword
                   args (presumably a mock ``call`` object).
            result : Value returned to the caller when the call is made.
        """
        self.expectations.append((call, result, defer.Deferred()))

    @defer.inlineCallbacks
    def await_calls(self, timeout=1000):
        """Wait until every registered expectation has been called.

        Args:
            timeout : Maximum time to wait, in milliseconds.
        Returns:
            A Deferred (this is an inlineCallbacks function) which fires once
            all expectations have been met; ``calls`` is reset at that point.
        Raises:
            AssertionError If the timeout expires first; the message lists
                           the expectations still pending at that moment.
        """
        deferred = defer.DeferredList(
            [d for _, _, d in self.expectations],
            fireOnOneErrback=True
        )

        def on_timeout():
            # Built when the timer fires, so the report reflects what is
            # still outstanding at timeout rather than at scheduling time.
            pending = [e for e in self.expectations if not e[2].called]
            deferred.errback(AssertionError(
                "%d pending calls left: %s" % (len(pending), pending)
            ))

        # Divide by a float: under Python 2, timeout/1000 is integer division
        # and would turn any sub-second timeout into 0.
        timer = reactor.callLater(timeout / 1000.0, on_timeout)

        try:
            yield deferred
        finally:
            # Also cancel when the wait failed for some other reason, so the
            # timer cannot later errback the already-fired DeferredList.
            if timer.active():
                timer.cancel()

        self.calls = []

    def assert_had_no_calls(self):
        """Raise AssertionError listing any recorded calls, resetting the
        recorded calls in the process. Does nothing if there were none.
        """
        if self.calls:
            calls = self.calls
            self.calls = []

            raise AssertionError("Expected not to received any calls, got:\n" +
                "\n".join([
                    "call(%s)" % _format_call(c[0], c[1]) for c in calls
                ])
            )