diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c4f0bbd3dd..8efd39c7f7 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 67f1013051..5ec568f4e6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
- "get_devices_by_remote",
+ "get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_devices_by_remote.return_value = (0, [])
+ self.datastore.get_device_updates_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)
@@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
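+        # set_received_txn_response returns a Deferred in the real datastore, so
+        # the stub must do likewise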
+ self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+ None
+ )
def test_started_typing_local(self):
self.room_members = [U_APPLE, U_BANANA]
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
+
+
+def get_test_https_policy():
+ """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+ Returns:
+ IPolicyForHTTPS
+ """
+ ca_file = get_test_ca_cert_file()
+ with open(ca_file) as stream:
+ content = stream.read()
+ cert = Certificate.loadPEM(content)
+ trust_root = trustRootFromCertificates([cert])
+ return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d7025264..cfcd98ff7d 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
+ # grab a hold of the TLS connection, in case it gets torn down
+ server_tls_connection = server_tls_protocol._tlsConnection
+
+ # fish the test server back out of the server-side TLS protocol.
+ http_protocol = server_tls_protocol.wrappedProtocol
+
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
- server_name = server_tls_protocol._tlsConnection.get_servername()
+ server_name = server_tls_connection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
- # fish the test server back out of the server-side TLS protocol.
- return server_tls_protocol.wrappedProtocol
+ return http_protocol
@defer.inlineCallbacks
def _make_get_request(self, uri):
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def _make_connection(
+ self, client_factory, server_factory, ssl=False, expected_sni=None
+ ):
+ """Builds a test server, and completes the outgoing client connection
+
+ Args:
+            client_factory (interfaces.IProtocolFactory): the factory that the
+ application is trying to use to make the outbound connection. We will
+ invoke it to build the client Protocol
+
+ server_factory (interfaces.IProtocolFactory): a factory to build the
+ server-side protocol
+
+ ssl (bool): If true, we will expect an ssl connection and wrap
+ server_factory with a TLSMemoryBIOFactory
+
+ expected_sni (bytes|None): the expected SNI value
+
+ Returns:
+ IProtocol: the server Protocol returned by server_factory
+ """
+ if ssl:
+ server_factory = _wrap_server_factory_for_tls(server_factory)
+
+ server_protocol = server_factory.buildProtocol(None)
+
+ # now, tell the client protocol factory to build the client protocol,
+ # and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(
+ FakeTransport(server_protocol, self.reactor, client_protocol)
+ )
+
+ # tell the server protocol to send its stuff back to the client, too
+ server_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, server_protocol)
+ )
+
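+        # if we wrapped the server factory in TLS, the HTTP protocol we care
+        # about is the one inside the TLS wrapper; also keep hold of the TLS
+        # connection so that we can check the SNI on it below.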
+ if ssl:
+ http_protocol = server_protocol.wrappedProtocol
+ tls_connection = server_protocol._tlsConnection
+ else:
+ http_protocol = server_protocol
+ tls_connection = None
+
+ # give the reactor a pump to get the TLS juices flowing (if needed)
+ self.reactor.advance(0)
+
+ if expected_sni is not None:
+ server_name = tls_connection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ return http_protocol
+
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 80)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request(self):
+ agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+ self.reactor.lookups["test.com"] = "1.2.3.4"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=True,
+ expected_sni=b"test.com",
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_http_request_via_proxy(self):
+ agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"http://test.com")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
+ self.assertEqual(port, 8888)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
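+        # requests sent via an HTTP proxy use the absolute URI in the request
+        # line, rather than just the path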
+ self.assertEqual(request.path, b"http://test.com")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+ def test_https_request_via_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ https_proxy=b"proxy.com",
+ )
+
+ self.reactor.lookups["proxy.com"] = "1.2.3.5"
+ d = agent.request(b"GET", b"https://test.com/abc")
+
+ # there should be a pending TCP connection
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.5")
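+        # no port was given in https_proxy, so the agent's default should be used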
+ self.assertEqual(port, 1080)
+
+ # make a test HTTP server, and wire up the client
+ proxy_server = self._make_connection(
+ client_factory, _get_test_protocol_factory()
+ )
+
+ # fish the transports back out so that we can do the old switcheroo
+ s2c_transport = proxy_server.transport
+ client_protocol = s2c_transport.other
+ c2s_transport = client_protocol.transport
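+        # (s2c_transport delivers the proxy server's bytes to the client;
+        # c2s_transport carries the client's bytes, and will be re-pointed at a
+        # TLS server below to emulate the tunnel set up by the CONNECT)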
+
+ # the FakeTransport is async, so we need to pump the reactor
+ self.reactor.advance(0)
+
+ # now there should be a pending CONNECT request
+ self.assertEqual(len(proxy_server.requests), 1)
+
+ request = proxy_server.requests[0]
+ self.assertEqual(request.method, b"CONNECT")
+ self.assertEqual(request.path, b"test.com:443")
+
+ # tell the proxy server not to close the connection
+ proxy_server.persistent = True
+
+ # this just stops the http Request trying to do a chunked response
+ # request.setHeader(b"Content-Length", b"0")
+ request.finish()
+
+ # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+ ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_protocol = ssl_factory.buildProtocol(None)
+ http_server = ssl_protocol.wrappedProtocol
+
+ ssl_protocol.makeConnection(
+ FakeTransport(client_protocol, self.reactor, ssl_protocol)
+ )
+ c2s_transport.other = ssl_protocol
+
+ self.reactor.advance(0)
+
+ server_name = ssl_protocol._tlsConnection.get_servername()
+ expected_sni = b"test.com"
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # now there should be a pending request
+ self.assertEqual(len(http_server.requests), 1)
+
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/abc")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ request.write(b"result")
+ request.finish()
+
+ self.reactor.advance(0)
+
+ resp = self.successResultOf(d)
+ body = self.successResultOf(treq.content(resp))
+ self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory (interfaces.IProtocolFactory): protocol factory to wrap
+ sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ if sanlist is None:
+ sanlist = [b"DNS:test.com"]
+
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+
+
+def _get_test_protocol_factory():
+ """Get a protocol Factory which will build an HTTPChannel
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ return server_factory
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..af2327fb66 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
config = self.default_config()
config["start_pushers"] = True
- hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+ hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
return hs
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 104349cdbd..4f924ce451 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
+ self.storage = hs.get_storage()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index a368117b43..b68e9fe082 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
)
msg, msgctx = self.build_event()
- self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
+ self.get_success(
+ self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ )
self.replicate()
event_source = RoomEventSource(self.hs)
@@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.master_store.persist_events([(event, context)], backfilled=True)
+ self.storage.persistence.persist_events(
+ [(event, context)], backfilled=True
+ )
)
else:
- self.get_success(self.master_store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index d3a4f717f7..8e1ca8b738 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -561,3 +561,81 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["groups"]
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+ """Test /purge_room admin API.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ def test_purge_room(self):
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ # All users have to have left the room.
+ self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/purge_room"
+ request, channel = self.make_request(
+ "POST",
+ url.encode("ascii"),
+ {"room_id": room_id},
+ access_token=self.admin_user_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Test that the following tables have been purged of all rows related to the room.
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ ):
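+            # count the rows in this table which still refer to the room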
+ count = self.get_success(
+ self.store._simple_select_one_onecol(
+                    table=table,
+ keyvalues={"room_id": room_id},
+ retcol="COUNT(*)",
+ desc="test_purge_room",
+ )
+ )
+
+ self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
diff --git a/tests/server.py b/tests/server.py
index e397ebe8fa..f878aeaada 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -161,7 +161,11 @@ def make_request(
path = path.encode("ascii")
# Decorate it to be the full path, if we're using shorthand
- if shorthand and not path.startswith(b"/_matrix"):
+ if (
+ shorthand
+ and not path.startswith(b"/_matrix")
+ and not path.startswith(b"/_synapse")
+ ):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
@@ -391,11 +395,24 @@ class FakeTransport(object):
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
- self.disconnected = True
+
+ # if we still have data to write, delay until that is done
+ if self.buffer:
+ logger.info(
+ "FakeTransport: Delaying disconnect until buffer is flushed"
+ )
+ else:
+ self.disconnected = True
def abortConnection(self):
logger.info("FakeTransport: abortConnection()")
- self.loseConnection()
+
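+        # unlike loseConnection, we don't wait for any outstanding writes to be
+        # flushed: the connection is torn down immediately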
+ if not self.disconnecting:
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(None)
+
+ self.disconnected = True
def pauseProducing(self):
if not self.producer:
@@ -426,6 +443,9 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _produce)
def write(self, byt):
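+        # catch any attempt to write to a transport which is being shut down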
+ if self.disconnecting:
+ raise Exception("Writing to disconnecting FakeTransport")
+
self.buffer = self.buffer + byt
# always actually do the write asynchronously. Some protocols (notably the
@@ -470,6 +490,10 @@ class FakeTransport(object):
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
+ if not self.buffer and self.disconnecting:
+ logger.info("FakeTransport: Buffer now empty, completing disconnect")
+ self.disconnected = True
+
def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
"""
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index dd49a14524..9b81b536f5 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
a.func.prefill(("foo",), ObservableDeferred(d))
- self.assertEquals(a.func("foo"), d.result)
+ self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..6f8d990959 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
- def test_get_devices_by_remote(self):
+ def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
- def test_get_devices_by_remote_limited(self):
+ def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
#
# first we should get a single update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
- now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+ now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
- received_device_ids = {update["device_id"] for update in device_updates}
+ received_device_ids = {
+ update["device_id"] for edu_type, update in device_updates
+ }
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 427d3c49c5..4561c3e383 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.store.persist_event(event_1, context_1))
+ self.get_success(self.storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.store.persist_event(event_2, context_2))
+ self.get_success(self.storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee45706f..3ddaa151fe 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield self.store.persist_event(
+ yield self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6ffb..9ddd17f73d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.store.persist_event(event, context))
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2db..43200654f1 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_datastore = self.store
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -63,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
builder
)
- yield self.store.persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@@ -82,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups_ids(
+ state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
)
self.assertEqual(len(state_group_map), 1)
@@ -101,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+ state_group_map = yield self.storage.state.get_state_groups(
+ self.room, [e2.event_id]
+ )
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -141,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield self.store.get_state_for_event(e5.event_id)
+ state = yield self.storage.state.get_state_for_event(e5.event_id)
self.assertIsNotNone(e4)
@@ -157,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -181,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -199,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield self.store.get_state_for_event(
+ state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -215,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+ group_ids = yield self.storage.state.get_state_groups_ids(
+ room_id, [e5.event_id]
+ )
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -237,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -250,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -267,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -287,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -304,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -317,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -331,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, True)
self.assertEqual(known_absent, set())
@@ -346,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
state_dict_ids.pop((e2.type, e2.state_key))
- self.store._state_group_cache.invalidate(group)
- self.store._state_group_cache.update(
- sequence=self.store._state_group_cache.sequence,
+ self.state_datastore._state_group_cache.invalidate(group)
+ self.state_datastore._state_group_cache.update(
+ sequence=self.state_datastore._state_group_cache.sequence,
key=group,
value=state_dict_ids,
# list fetched keys so it knows it's partial
fetched_keys=((e1.type, e1.state_key),),
)
- (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
- group
- )
+ (
+ is_all,
+ known_absent,
+ state_dict_ids,
+ ) = self.state_datastore._state_group_cache.get(group)
self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@@ -369,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -381,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
@@ -394,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -405,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
@@ -424,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -435,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -448,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -459,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
- self.store._state_group_members_cache,
+ (
+ state_dict,
+ is_all,
+ ) = yield self.state_datastore._get_state_for_group_using_cache(
+ self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index a73f18f88e..7d82b58466 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
@@ -58,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
)
self.handler = self.homeserver.get_handlers().federation_handler
- self.handler.do_auth = lambda *a, **b: succeed(True)
+ self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+ context
+ )
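+        # (do_auth returns the event context, rather than a simple True)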
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
@@ -75,7 +78,8 @@ class MessageAcceptTests(unittest.TestCase):
self.assertEqual(
self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0],
"$join:test.serv",
@@ -97,7 +101,8 @@ class MessageAcceptTests(unittest.TestCase):
# Figure out what the most recent event is
most_recent = self.successResultOf(
maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room,
+ self.room_id,
)
)[0]
@@ -137,6 +142,6 @@ class MessageAcceptTests(unittest.TestCase):
# Make sure the invalid event isn't there
extrem = maybeDeferred(
- self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+ self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb46..38246555bd 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -158,10 +158,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = StateGroupStore()
+ storage = Mock(main=self.store, state=self.store)
hs = Mock(
spec_set=[
"config",
"get_datastore",
+ "get_storage",
"get_auth",
"get_state_handler",
"get_clock",
@@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+ hs.get_storage.return_value = storage
self.state = StateHandler(hs)
self.event_id = 0
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a0035d..f7381b2885 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+from mock import Mock
+
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens.
filtered = yield filter_events_for_server(
- self.store, "test_server", events_to_filter
+ self.storage, "test_server", events_to_filter
)
for i in range(0, len(events_to_filter)):
@@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
builder
)
- yield self.hs.get_datastore().persist_event(event, context)
+ yield self.storage.persistence.persist_event(event, context)
return event
@defer.inlineCallbacks
@@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering")
start = time.time()
+
+ storage = Mock()
+ storage.main = test_store
+ storage.state = test_store
+
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index f907903511..39e360fe24 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase):
obj.mock.return_value = ["spam", "eggs"]
r = obj.fn(1, 2)
- self.assertEqual(r, ["spam", "eggs"])
+ self.assertEqual(r.result, ["spam", "eggs"])
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()
# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
r = obj.fn(1, 3)
- self.assertEqual(r, ["chips"])
+ self.assertEqual(r.result, ["chips"])
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()
diff --git a/tests/utils.py b/tests/utils.py
index 8cced4b7e8..7dc9bdc505 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -325,10 +325,16 @@ def setup_test_homeserver(
if homeserverToUse.__name__ == "TestHomeServer":
hs.setup_master()
else:
+ # If we have been given an explicit datastore we probably want to mock
+ # out the DataStores somehow too. This all feels a bit wrong, but then
+ # mocking the stores feels wrong too.
+ datastores = Mock(datastore=datastore)
+
hs = homeserverToUse(
name,
db_pool=None,
datastore=datastore,
+ datastores=datastores,
config=config,
version_string="Synapse/tests",
database_engine=db_engine,
@@ -646,7 +652,7 @@ def create_room(hs, room_id, creator_id):
creator_id (str)
"""
- store = hs.get_datastore()
+ persistence_store = hs.get_storage().persistence
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
@@ -663,4 +669,4 @@ def create_room(hs, room_id, creator_id):
event, context = yield event_creation_handler.create_new_client_event(builder)
- yield store.persist_event(event, context)
+ yield persistence_store.persist_event(event, context)
|