diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 75ae740b43..08214b0013 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_id, stream_ordering = self.get_success(
self.hs.get_datastores().main.db_pool.execute(
"test:get_destination_rooms",
- None,
"""
SELECT event_id, stream_ordering
FROM destination_rooms dr
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 867dbd6001..78646cb5dc 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -31,7 +31,12 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType
+from synapse.types import (
+ JsonDict,
+ MultiWriterStreamToken,
+ RoomStreamToken,
+ StreamKeyType,
+)
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -156,6 +161,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
result = self.successResultOf(
defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
)
+ assert result is not None
self.mock_as_api.query_alias.assert_called_once_with(
interested_service, room_alias_str
@@ -304,7 +310,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
- StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
+ StreamKeyType.RECEIPT,
+ MultiWriterStreamToken(stream=580),
+ ["@fakerecipient:example.com"],
)
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, ephemeral=[event]
@@ -332,7 +340,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.handler.notify_interested_services_ephemeral(
- StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"]
+ StreamKeyType.RECEIPT,
+ MultiWriterStreamToken(stream=580),
+ ["@fakerecipient:example.com"],
)
# This method will be called, but with an empty list of events
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -635,7 +645,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
services=[interested_appservice],
stream_key=StreamKeyType.RECEIPT,
- new_token=stream_token,
+ new_token=MultiWriterStreamToken(stream=stream_token),
users=[self.exclusive_as_user],
)
)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 41c8c44e02..173b14521a 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import itertools
from typing import Optional, cast
from unittest.mock import Mock, call
@@ -33,6 +33,7 @@ from synapse.handlers.presence import (
IDLE_TIMER,
LAST_ACTIVE_GRANULARITY,
SYNC_ONLINE_TIMEOUT,
+ PresenceHandler,
handle_timeout,
handle_update,
)
@@ -66,7 +67,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertTrue(persist_and_notify)
@@ -108,7 +114,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertFalse(persist_and_notify)
@@ -153,7 +164,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertFalse(persist_and_notify)
@@ -196,7 +212,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertTrue(persist_and_notify)
@@ -231,7 +252,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=False,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertFalse(persist_and_notify)
@@ -265,7 +291,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
new_state = prev_state.copy_and_replace(state=PresenceState.OFFLINE)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertTrue(persist_and_notify)
@@ -287,7 +318,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
new_state = prev_state.copy_and_replace(state=PresenceState.UNAVAILABLE)
state, persist_and_notify, federation_ping = handle_update(
- prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
)
self.assertTrue(persist_and_notify)
@@ -347,6 +383,41 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
# They should be identical.
self.assertEqual(presence_states_compare, db_presence_states)
+ @parameterized.expand(
+ itertools.permutations(
+ (
+ PresenceState.BUSY,
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ ),
+ 2,
+ )
+ )
+ def test_override(self, initial_state: str, final_state: str) -> None:
+ """Overridden statuses should not go into the wheel timer."""
+ wheel_timer = Mock()
+ user_id = "@foo:bar"
+ now = 5000000
+
+ prev_state = UserPresenceState.default(user_id)
+ prev_state = prev_state.copy_and_replace(
+ state=initial_state, last_active_ts=now, currently_active=True
+ )
+
+ new_state = prev_state.copy_and_replace(state=final_state, last_active_ts=now)
+
+ handle_update(
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=True,
+ )
+
+ wheel_timer.insert.assert_not_called()
+
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
@@ -738,7 +809,6 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
- self.clock = hs.get_clock()
def test_external_process_timeout(self) -> None:
"""Test that if an external process doesn't update the records for a while
@@ -1471,6 +1541,29 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(new_state.state, state)
self.assertEqual(new_state.status_msg, status_msg)
+ @unittest.override_config({"presence": {"enabled": "untracked"}})
+ def test_untracked_does_not_idle(self) -> None:
+ """Untracked presence should not idle."""
+
+ # Mark user as online, this needs to reach into internals in order to
+ # bypass checks.
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+ assert isinstance(self.presence_handler, PresenceHandler)
+ self.get_success(
+ self.presence_handler._update_states(
+ [state.copy_and_replace(state=PresenceState.ONLINE)]
+ )
+ )
+
+ # Ensure the update took.
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
+ # The timeout should not fire and the state should be the same.
+ self.reactor.advance(SYNC_ONLINE_TIMEOUT)
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d11ded6c5b..76c56d5434 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
@@ -68,10 +68,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- async def get_all_room_state(self) -> List[Dict[str, Any]]:
- return await self.store.db_pool.simple_select_list(
- "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
+ async def get_all_room_state(self) -> List[Optional[str]]:
+ rows = cast(
+ List[Tuple[Optional[str]]],
+ await self.store.db_pool.simple_select_list(
+ "room_stats_state", None, retcols=("topic",)
+ ),
)
+ return [r[0] for r in rows]
def _get_current_stats(
self, stats_type: str, stat_id: str
@@ -130,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r = self.get_success(self.get_all_room_state())
self.assertEqual(len(r), 1)
- self.assertEqual(r[0]["topic"], "foo")
+ self.assertEqual(r[0], "foo")
def test_create_user(self) -> None:
"""
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 528cdee34b..d5306e7ee0 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -15,14 +15,20 @@ import os.path
import subprocess
from typing import List
+from incremental import Version
from zope.interface import implementer
+import twisted
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.address import IPv4Address
-from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.interfaces import (
+ IOpenSSLServerConnectionCreator,
+ IProtocolFactory,
+ IReactorTime,
+)
from twisted.internet.ssl import Certificate, trustRootFromCertificates
-from twisted.protocols.tls import TLSMemoryBIOProtocol
+from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401
@@ -153,6 +159,33 @@ class TestServerTLSConnectionFactory:
return Connection(ctx, None)
+def wrap_server_factory_for_tls(
+ factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes]
+) -> TLSMemoryBIOFactory:
+ """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+ The resultant factory will create a TLS server which presents a certificate
+ signed by our test CA, valid for the domains in `sanlist`
+
+ Args:
+ factory: protocol factory to wrap
+ sanlist: list of domains the cert should be valid for
+
+ Returns:
+ interfaces.IProtocolFactory
+ """
+ connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+ # Twisted > 23.8.0 has a different API that accepts a clock.
+ if twisted.version <= Version("Twisted", 23, 8, 0):
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory
+ )
+ else:
+ return TLSMemoryBIOFactory(
+ connection_creator, isClient=False, wrappedFactory=factory, clock=clock # type: ignore[call-arg]
+ )
+
+
# A dummy address, useful for tests that use FakeTransport and don't care about where
# packets are going to/coming from.
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 9f63fa6fa8..0f623ae50b 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -31,7 +31,7 @@ from twisted.internet.interfaces import (
IProtocolFactory,
)
from twisted.internet.protocol import Factory, Protocol
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
from twisted.web.http import HTTPChannel, Request
@@ -57,11 +57,7 @@ from synapse.types import ISynapseReactor
from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
-from tests.http import (
- TestServerTLSConnectionFactory,
- dummy_address,
- get_test_ca_cert_file,
-)
+from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.utils import checked_cast, default_config
@@ -125,7 +121,18 @@ class MatrixFederationAgentTests(unittest.TestCase):
# build the test server
server_factory = _get_test_protocol_factory()
if ssl:
- server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
+ server_factory = wrap_server_factory_for_tls(
+ server_factory,
+ self.reactor,
+ tls_sanlist
+ or [
+ b"DNS:testserv",
+ b"DNS:target-server",
+ b"DNS:xn--bcher-kva.com",
+ b"IP:1.2.3.4",
+ b"IP:::1",
+ ],
+ )
server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
@@ -435,8 +442,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
request.finish()
# now we make another test server to act as the upstream HTTP server.
- server_ssl_protocol = _wrap_server_factory_for_tls(
- _get_test_protocol_factory()
+ server_ssl_protocol = wrap_server_factory_for_tls(
+ _get_test_protocol_factory(),
+ self.reactor,
+ sanlist=[
+ b"DNS:testserv",
+ b"DNS:target-server",
+ b"DNS:xn--bcher-kva.com",
+ b"IP:1.2.3.4",
+ b"IP:::1",
+ ],
).buildProtocol(dummy_address)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
@@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))
-def _wrap_server_factory_for_tls(
- factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> TLSMemoryBIOFactory:
- """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
- The resultant factory will create a TLS server which presents a certificate
- signed by our test CA, valid for the domains in `sanlist`
- Args:
- factory: protocol factory to wrap
- sanlist: list of domains the cert should be valid for
- Returns:
- interfaces.IProtocolFactory
- """
- if sanlist is None:
- sanlist = [
- b"DNS:testserv",
- b"DNS:target-server",
- b"DNS:xn--bcher-kva.com",
- b"IP:1.2.3.4",
- b"IP:::1",
- ]
-
- connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
- return TLSMemoryBIOFactory(
- connection_creator, isClient=False, wrappedFactory=factory
- )
-
-
def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
Returns:
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 36472e57a8..d524c183f8 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -335,7 +335,7 @@ class Deferred__next__Patch:
self._request_number = request_number
self._seen_awaits = seen_awaits
- self._original_Deferred___next__ = Deferred.__next__
+ self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore]
# The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index ab94f3f67a..bf1d287699 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -70,7 +70,7 @@ class FederationClientTests(HomeserverTestCase):
"""
@defer.inlineCallbacks
- def do_request() -> Generator["Deferred[object]", object, object]:
+ def do_request() -> Generator["Deferred[Any]", object, object]:
with LoggingContext("one") as context:
fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar")
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 8164b0b78e..1f117276cf 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -29,18 +29,14 @@ from twisted.internet.endpoints import (
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory, Protocol
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import BasicProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy
-from tests.http import (
- TestServerTLSConnectionFactory,
- dummy_address,
- get_test_https_policy,
-)
+from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
from tests.utils import checked_cast
@@ -217,6 +213,27 @@ class ProxyParserTests(TestCase):
)
+class TestBasicProxyCredentials(TestCase):
+ def test_long_user_pass_string_encoded_without_newlines(self) -> None:
+ """Reproduces https://github.com/matrix-org/synapse/pull/16504."""
+ proxy_connection_string = b"looooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooonguser:pass@proxy.local:9988"
+ _, _, _, creds = parse_proxy(proxy_connection_string)
+ assert creds is not None # for mypy's benefit
+ self.assertIsInstance(creds, BasicProxyCredentials)
+
+ auth_value = creds.as_proxy_authorization_value()
+ self.assertNotIn(b"\n", auth_value)
+ self.assertEqual(
+ creds.as_proxy_authorization_value(),
+ b"Basic bG9vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vb29vbmd1c2VyOnBhc3M=",
+ )
+ basic_auth_payload = creds.as_proxy_authorization_value().split(b" ")[1]
+ self.assertEqual(
+ base64.b64decode(basic_auth_payload),
+ b"looooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooonguser:pass",
+ )
+
+
class MatrixFederationAgentTests(TestCase):
def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
@@ -251,7 +268,9 @@ class MatrixFederationAgentTests(TestCase):
the server Protocol returned by server_factory
"""
if ssl:
- server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
+ server_factory = wrap_server_factory_for_tls(
+ server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"]
+ )
server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
@@ -618,8 +637,8 @@ class MatrixFederationAgentTests(TestCase):
request.finish()
# now we make another test server to act as the upstream HTTP server.
- server_ssl_protocol = _wrap_server_factory_for_tls(
- _get_test_protocol_factory()
+ server_ssl_protocol = wrap_server_factory_for_tls(
+ _get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
).buildProtocol(dummy_address)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
@@ -785,7 +804,9 @@ class MatrixFederationAgentTests(TestCase):
request.finish()
# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
- ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+ ssl_factory = wrap_server_factory_for_tls(
+ _get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
+ )
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
http_server = ssl_protocol.wrappedProtocol
@@ -849,30 +870,6 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
-def _wrap_server_factory_for_tls(
- factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> TLSMemoryBIOFactory:
- """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
-
- The resultant factory will create a TLS server which presents a certificate
- signed by our test CA, valid for the domains in `sanlist`
-
- Args:
- factory: protocol factory to wrap
- sanlist: list of domains the cert should be valid for
-
- Returns:
- interfaces.IProtocolFactory
- """
- if sanlist is None:
- sanlist = [b"DNS:test.com"]
-
- connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
- return TLSMemoryBIOFactory(
- connection_creator, isClient=False, wrappedFactory=factory
- )
-
-
def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
diff --git a/tests/module_api/test_event_unsigned_addition.py b/tests/module_api/test_event_unsigned_addition.py
new file mode 100644
index 0000000000..b64426b1ac
--- /dev/null
+++ b/tests/module_api/test_event_unsigned_addition.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase
+from synapse.rest import admin, login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventUnsignedAdditionTestCase(HomeserverTestCase):
+ servlets = [
+ room.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self._store = homeserver.get_datastores().main
+ self._module_api = homeserver.get_module_api()
+ self._account_data_mgr = self._module_api.account_data_manager
+
+ def test_annotate_event(self) -> None:
+ """Test that we can annotate an event when we request it from the
+ server.
+ """
+
+ async def add_unsigned_event(event: EventBase) -> JsonDict:
+ return {"test_key": event.event_id}
+
+ self._module_api.register_add_extra_fields_to_unsigned_client_event_callbacks(
+ add_field_to_unsigned_callback=add_unsigned_event
+ )
+
+ user_id = self.register_user("user", "password")
+ token = self.login("user", "password")
+
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ result = self.helper.send(room_id, "Hello!", tok=token)
+ event_id = result["event_id"]
+
+ event_json = self.helper.get_event(room_id, event_id, tok=token)
+ self.assertEqual(event_json["unsigned"].get("test_key"), event_id)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 128fc3e046..b8ab4ee54b 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -14,6 +14,8 @@
from typing import Any, List, Optional
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
@@ -21,6 +23,8 @@ from synapse.events import EventBase
from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
+ _MAX_STATE_UPDATES_PER_ROOM,
+ EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -106,11 +110,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
- def test_update_function_huge_state_change(self) -> None:
+ @parameterized.expand(
+ [(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
+ )
+ def test_update_function_huge_state_change(
+ self, num_state_changes: int, collapse_state_changes: bool
+ ) -> None:
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.
+
+ Args:
+ num_state_changes: The number of state changes to create.
+ collapse_state_changes: Whether the state changes are expected to be
+ collapsed or not.
"""
# we want to generate lots of state changes at a single stream ID.
@@ -145,7 +159,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
events = [
self._inject_state_event(sender=OTHER_USER)
- for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
+ for _ in range(num_state_changes)
]
self.replicate()
@@ -202,8 +216,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
- # first check the first two rows, which should be state1
-
+ # first check the first two rows, which should be the state1 event.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -217,7 +230,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)
- # now the last two rows, which should be state2
+ # now the last two rows, which should be the state2 event.
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -231,34 +244,54 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id)
- # that should leave us with the rows for the PL event
- self.assertEqual(len(received_rows), len(events) + 2)
+    # Based on the number of state changes, the replication rows may be
+ if collapse_state_changes:
+ # that should leave us with the rows for the PL event, the state changes
+ # get collapsed into a single row.
+ self.assertEqual(len(received_rows), 2)
- stream_name, token, row = received_rows.pop(0)
- self.assertEqual("events", stream_name)
- self.assertIsInstance(row, EventsStreamRow)
- self.assertEqual(row.type, "ev")
- self.assertIsInstance(row.data, EventsStreamEventRow)
- self.assertEqual(row.data.event_id, pl_event.event_id)
+ stream_name, token, row = received_rows.pop(0)
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_event.event_id)
- # the state rows are unsorted
- state_rows: List[EventsStreamCurrentStateRow] = []
- for stream_name, _, row in received_rows:
+ stream_name, token, row = received_rows.pop(0)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state-all")
+ self.assertIsInstance(row.data, EventsStreamAllStateRow)
+ self.assertEqual(row.data.room_id, state2.room_id)
+
+ else:
+ # that should leave us with the rows for the PL event
+ self.assertEqual(len(received_rows), len(events) + 2)
+
+ stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
- self.assertEqual(row.type, "state")
- self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
- state_rows.append(row.data)
-
- state_rows.sort(key=lambda r: r.state_key)
-
- sr = state_rows.pop(0)
- self.assertEqual(sr.type, EventTypes.PowerLevels)
- self.assertEqual(sr.event_id, pl_event.event_id)
- for sr in state_rows:
- self.assertEqual(sr.type, "test_state_event")
- # "None" indicates the state has been deleted
- self.assertIsNone(sr.event_id)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, pl_event.event_id)
+
+ # the state rows are unsorted
+ state_rows: List[EventsStreamCurrentStateRow] = []
+ for stream_name, _, row in received_rows:
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "state")
+ self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
+ state_rows.append(row.data)
+
+ state_rows.sort(key=lambda r: r.state_key)
+
+ sr = state_rows.pop(0)
+ self.assertEqual(sr.type, EventTypes.PowerLevels)
+ self.assertEqual(sr.event_id, pl_event.event_id)
+ for sr in state_rows:
+ self.assertEqual(sr.type, "test_state_event")
+ # "None" indicates the state has been deleted
+ self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index b230a6c361..1e9994cc0b 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,9 +15,7 @@ import logging
import os
from typing import Any, Optional, Tuple
-from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
@@ -27,7 +25,11 @@ from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
-from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.http import (
+ TestServerTLSConnectionFactory,
+ get_test_ca_cert_file,
+ wrap_server_factory_for_tls,
+)
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
@@ -94,7 +96,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
# build the test server
- server_tls_protocol = _build_test_server(get_connection_factory())
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_protocol = wrap_server_factory_for_tls(
+ server_factory, self.reactor, sanlist=[b"DNS:example.com"]
+ ).buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -114,7 +122,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
)
# fish the test server back out of the server-side TLS protocol.
- http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
+ http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
@@ -240,40 +248,6 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path))
-def get_connection_factory() -> TestServerTLSConnectionFactory:
- # this needs to happen once, but not until we are ready to run the first test
- global test_server_connection_factory
- if test_server_connection_factory is None:
- test_server_connection_factory = TestServerTLSConnectionFactory(
- sanlist=[b"DNS:example.com"]
- )
- return test_server_connection_factory
-
-
-def _build_test_server(
- connection_creator: IOpenSSLServerConnectionCreator,
-) -> TLSMemoryBIOProtocol:
- """Construct a test server
-
- This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
-
- Args:
- connection_creator: thing to build SSL connections
-
- Returns:
- TLSMemoryBIOProtocol
- """
- server_factory = Factory.forProtocol(HTTPChannel)
- # Request.finish expects the factory to have a 'log' method.
- server_factory.log = _log_request
-
- server_tls_factory = TLSMemoryBIOFactory(
- connection_creator, isClient=False, wrappedFactory=server_factory
- )
-
- return server_tls_factory.buildProtocol(None)
-
-
def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py
new file mode 100644
index 0000000000..41876b36de
--- /dev/null
+++ b/tests/replication/test_sharded_receipts.py
@@ -0,0 +1,243 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import ReceiptTypes
+from synapse.rest import admin
+from synapse.rest.client import login, receipts, room, sync
+from synapse.server import HomeServer
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.types import StreamToken
+from synapse.util import Clock
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptsShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks receipts sharding works"""
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ receipts.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ self.room_creator = self.hs.get_room_creation_handler()
+ self.store = hs.get_datastores().main
+
+ def default_config(self) -> dict:
+ conf = super().default_config()
+ conf["stream_writers"] = {"receipts": ["worker1", "worker2"]}
+ conf["instance_map"] = {
+ "main": {"host": "testserv", "port": 8765},
+ "worker1": {"host": "testserv", "port": 1001},
+ "worker2": {"host": "testserv", "port": 1002},
+ }
+ return conf
+
+ def test_basic(self) -> None:
+ """Simple test to ensure that receipts can be sent on multiple
+ workers.
+ """
+
+ worker1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "worker1"},
+ )
+ worker1_site = self._hs_to_site[worker1]
+
+ worker2 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "worker2"},
+ )
+ worker2_site = self._hs_to_site[worker2]
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Create a room
+ room_id = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room_id, user=self.other_user_id, tok=self.other_access_token
+ )
+
+        # First user sends a message, the other user sends a receipt.
+ response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ channel = make_request(
+ reactor=self.reactor,
+ site=worker1_site,
+ method="POST",
+ path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
+ access_token=access_token,
+ content={},
+ )
+ self.assertEqual(200, channel.code)
+
+ # Now we do it again using the second worker
+ response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ channel = make_request(
+ reactor=self.reactor,
+ site=worker2_site,
+ method="POST",
+ path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
+ access_token=access_token,
+ content={},
+ )
+ self.assertEqual(200, channel.code)
+
+ def test_vector_clock_token(self) -> None:
+ """Tests that using a stream token with a vector clock component works
+ correctly with basic /sync usage.
+ """
+
+ worker_hs1 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "worker1"},
+ )
+ worker1_site = self._hs_to_site[worker_hs1]
+
+ worker_hs2 = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "worker2"},
+ )
+ worker2_site = self._hs_to_site[worker_hs2]
+
+ sync_hs = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "sync"},
+ )
+ sync_hs_site = self._hs_to_site[sync_hs]
+
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ store = self.hs.get_datastores().main
+
+ room_id = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room_id, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+ first_event = response["event_id"]
+
+ # Do an initial sync so that we're up to date.
+ channel = make_request(
+ self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+ )
+ next_batch = channel.json_body["next_batch"]
+
+        # We now gut wrench into the receipts stream MultiWriterIdGenerator on
+        # worker2 to mimic it getting stuck persisting a receipt. This ensures
+        # that when we send a receipt on worker1 we end up in a state where
+        # worker2's receipts stream position lags that on worker1, resulting in
+        # a receipts token with a non-empty instance map component.
+        #
+        # Worker2's receipts stream position will not advance until we call
+        # __aexit__ below.
+ worker_store2 = worker_hs2.get_datastores().main
+ assert isinstance(worker_store2._receipts_id_gen, MultiWriterIdGenerator)
+
+ actx = worker_store2._receipts_id_gen.get_next()
+ self.get_success(actx.__aenter__())
+
+ channel = make_request(
+ reactor=self.reactor,
+ site=worker1_site,
+ method="POST",
+ path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{first_event}",
+ access_token=access_token,
+ content={},
+ )
+ self.assertEqual(200, channel.code)
+
+ # Assert that the current stream token has an instance map component, as
+ # we are trying to test vector clock tokens.
+ receipts_token = store.get_max_receipt_stream_id()
+ self.assertGreater(len(receipts_token.instance_map), 0)
+
+ # Check that syncing still gets the new receipt, despite the gap in the
+ # stream IDs.
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ f"/sync?since={next_batch}",
+ access_token=access_token,
+ )
+
+ # We should only see the new event and nothing else
+ self.assertIn(room_id, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"]
+ self.assertEqual(len(events), 1)
+ self.assertIn(first_event, events[0]["content"])
+
+        # Get the next batch and make sure it's a vector clock style token.
+ vector_clock_token = channel.json_body["next_batch"]
+ parsed_token = self.get_success(
+ StreamToken.from_string(store, vector_clock_token)
+ )
+ self.assertGreaterEqual(len(parsed_token.receipt_key.instance_map), 1)
+
+ # Now that we've got a vector clock token we finish the fake persisting
+ # a receipt we started above.
+ self.get_success(actx.__aexit__(None, None, None))
+
+        # Now try to send another receipt to the other worker.
+ response = self.helper.send(room_id, body="Hi!", tok=self.other_access_token)
+ second_event = response["event_id"]
+
+ channel = make_request(
+ reactor=self.reactor,
+ site=worker2_site,
+ method="POST",
+ path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{second_event}",
+ access_token=access_token,
+ content={},
+ )
+
+ channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ f"/sync?since={vector_clock_token}",
+ access_token=access_token,
+ )
+
+ self.assertIn(room_id, channel.json_body["rooms"]["join"])
+
+ events = channel.json_body["rooms"]["join"][room_id]["ephemeral"]["events"]
+ self.assertEqual(len(events), 1)
+ self.assertIn(second_event, events[0]["content"])
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 66b387cea3..4e89107e54 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
PUT to the status endpoint with use_presence enabled will call
set_state on the presence handler.
"""
- self.hs.config.server.use_presence = True
+ self.hs.config.server.presence_enabled = True
body = {"presence": "here", "status_msg": "beep boop"}
channel = self.make_request(
@@ -63,7 +63,22 @@ class PresenceTestCase(unittest.HomeserverTestCase):
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None:
"""
- PUT to the status endpoint with use_presence disabled will NOT call
+ PUT to the status endpoint with presence disabled will NOT call
+ set_state on the presence handler.
+ """
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(self.presence_handler.set_state.call_count, 0)
+
+ @unittest.override_config({"presence": {"enabled": "untracked"}})
+ def test_put_presence_untracked(self) -> None:
+ """
+ PUT to the status endpoint with presence untracked will NOT call
set_state on the presence handler.
"""
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index d3e06bf6b3..534dc339f3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -243,7 +243,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
assert event is not None
time_now = self.clock.time_msec()
- serialized = self.serializer.serialize_event(event, time_now)
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
return serialized
diff --git a/tests/server.py b/tests/server.py
index 08633fe640..cfb0fb823b 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -43,9 +43,11 @@ from typing import (
from unittest.mock import Mock
import attr
+from incremental import Version
from typing_extensions import ParamSpec
from zope.interface import implementer
+import twisted
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@@ -474,6 +476,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
+ # In order for the TLS protocol tests to work, modify _get_default_clock
+ # on newer Twisted versions to use the test reactor's clock.
+ #
+ # This is *super* dirty since it is never undone and relies on the next
+ # test to overwrite it.
+ if twisted.version > Version("Twisted", 23, 8, 0):
+ from twisted.protocols import tls
+
+ tls._get_default_clock = lambda: self # type: ignore[attr-defined]
+
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index 71db47405e..98b01086bc 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -117,7 +117,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None:
columns += expected_row.keys()
- rows = self.get_success(
+ row_tuples = self.get_success(
self.store.db_pool.simple_select_list(
table=table,
keyvalues={
@@ -134,22 +134,22 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
if expected_row is not None:
self.assertEqual(
- len(rows),
+ len(row_tuples),
1,
f"Background update did not leave behind latest receipt in {table}",
)
self.assertEqual(
- rows[0],
- {
- "room_id": room_id,
- "receipt_type": receipt_type,
- "user_id": user_id,
- **expected_row,
- },
+ row_tuples[0],
+ (
+ room_id,
+ receipt_type,
+ user_id,
+ *expected_row.values(),
+ ),
)
else:
self.assertEqual(
- len(rows),
+ len(row_tuples),
0,
f"Background update did not remove all duplicate receipts from {table}",
)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 8bbf936ae9..8cbc974ac4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import secrets
-from typing import Generator, Tuple
+from typing import Generator, List, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
@@ -47,15 +47,15 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
)
def _dump_table_to_tuple(self) -> Generator[Tuple[int, str, str], None, None]:
- res = self.get_success(
- self.storage.db_pool.simple_select_list(
- self.table_name, None, ["id, username, value"]
- )
+ yield from cast(
+ List[Tuple[int, str, str]],
+ self.get_success(
+ self.storage.db_pool.simple_select_list(
+ self.table_name, None, ["id, username, value"]
+ )
+ ),
)
- for i in res:
- yield (i["id"], i["username"], i["value"])
-
def test_upsert_many(self) -> None:
"""
Upsert_many will perform the upsert operation across a batch of data.
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index abf7d0564d..67ea640902 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple, cast
from unittest.mock import AsyncMock, Mock
import yaml
@@ -456,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
);
"""
self.get_success(
- self.store.db_pool.execute(
- "test_not_null_constraint", lambda _: None, table_sql
+ self.store.db_pool.runInteraction(
+ "test_not_null_constraint", lambda txn: txn.execute(table_sql)
)
)
@@ -465,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
# using SQLite.
index_sql = "CREATE INDEX test_index ON test_constraint(a)"
self.get_success(
- self.store.db_pool.execute(
- "test_not_null_constraint", lambda _: None, index_sql
+ self.store.db_pool.runInteraction(
+ "test_not_null_constraint", lambda txn: txn.execute(index_sql)
)
)
@@ -526,15 +527,18 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates()
# Check the correct values are in the new table.
- rows = self.get_success(
- self.store.db_pool.simple_select_list(
- table="test_constraint",
- keyvalues={},
- retcols=("a", "b"),
- )
+ rows = cast(
+ List[Tuple[int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="test_constraint",
+ keyvalues={},
+ retcols=("a", "b"),
+ )
+ ),
)
- self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+ self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected.
self.get_failure(
@@ -570,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
);
"""
self.get_success(
- self.store.db_pool.execute(
- "test_foreign_key_constraint", lambda _: None, base_sql
+ self.store.db_pool.runInteraction(
+ "test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
)
)
self.get_success(
- self.store.db_pool.execute(
- "test_foreign_key_constraint", lambda _: None, table_sql
+ self.store.db_pool.runInteraction(
+ "test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
)
)
@@ -640,14 +644,17 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
self.wait_for_background_updates()
# Check the correct values are in the new table.
- rows = self.get_success(
- self.store.db_pool.simple_select_list(
- table="test_constraint",
- keyvalues={},
- retcols=("a", "b"),
- )
+ rows = cast(
+ List[Tuple[int, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="test_constraint",
+ keyvalues={},
+ retcols=("a", "b"),
+ )
+ ),
)
- self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
+ self.assertCountEqual(rows, [(1, 1), (3, 3)])
# And check that invalid rows get correctly rejected.
self.get_failure(
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 256d28e4c9..e4a52c301e 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -146,7 +146,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
- self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
+ self.mock_txn.fetchall.return_value = [(1,), (2,), (3,)]
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield defer.ensureDeferred(
@@ -155,7 +155,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
)
- self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
+ self.assertEqual([(1,), (2,), (3,)], ret)
self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 0c054a598f..8e4393d843 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional, Tuple, cast
from unittest.mock import AsyncMock
from parameterized import parameterized
@@ -97,26 +97,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
self.pump(0)
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
self.assertEqual(
- result,
- [
- {
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": None,
- "last_seen": 12345678000,
- }
- ],
+ result, [("access_token", "ip", "user_agent", None, 12345678000)]
)
# Add another & trigger the storage loop
@@ -128,26 +128,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
self.pump(0)
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
# Only one result, has been upserted.
self.assertEqual(
- result,
- [
- {
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": None,
- "last_seen": 12345878000,
- }
- ],
+ result, [("access_token", "ip", "user_agent", None, 12345878000)]
)
@parameterized.expand([(False,), (True,)])
@@ -177,25 +177,23 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
else:
# Check that the new IP and user agent has not been stored yet
- db_result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="devices",
- keyvalues={},
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ db_result = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ),
+ ),
),
)
- self.assertEqual(
- db_result,
- [
- {
- "user_id": user_id,
- "device_id": device_id,
- "ip": None,
- "user_agent": None,
- "last_seen": None,
- },
- ],
- )
+ self.assertEqual(db_result, [(user_id, None, None, device_id, None)])
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
@@ -261,30 +259,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# Check that the new IP and user agent has not been stored yet
- db_result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="devices",
- keyvalues={},
- retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ db_result = cast(
+ List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="devices",
+ keyvalues={},
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+ ),
),
)
self.assertCountEqual(
db_result,
[
- {
- "user_id": user_id,
- "device_id": device_id_1,
- "ip": "ip_1",
- "user_agent": "user_agent_1",
- "last_seen": 12345678000,
- },
- {
- "user_id": user_id,
- "device_id": device_id_2,
- "ip": "ip_2",
- "user_agent": "user_agent_2",
- "last_seen": 12345678000,
- },
+ (user_id, "ip_1", "user_agent_1", device_id_1, 12345678000),
+ (user_id, "ip_2", "user_agent_2", device_id_2, 12345678000),
],
)
@@ -385,28 +374,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# Check that the new IP and user agent has not been stored yet
- db_result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={},
- retcols=("access_token", "ip", "user_agent", "last_seen"),
+ db_result = cast(
+ List[Tuple[str, str, str, int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=("access_token", "ip", "user_agent", "last_seen"),
+ ),
),
)
self.assertEqual(
db_result,
[
- {
- "access_token": "access_token",
- "ip": "ip_1",
- "user_agent": "user_agent_1",
- "last_seen": 12345678000,
- },
- {
- "access_token": "access_token",
- "ip": "ip_2",
- "user_agent": "user_agent_2",
- "last_seen": 12345678000,
- },
+ ("access_token", "ip_1", "user_agent_1", 12345678000),
+ ("access_token", "ip_2", "user_agent_2", 12345678000),
],
)
@@ -600,39 +582,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
# We should see that in the DB
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
self.assertEqual(
result,
- [
- {
- "access_token": "access_token",
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": device_id,
- "last_seen": 0,
- }
- ],
+ [("access_token", "ip", "user_agent", device_id, 0)],
)
# Now advance by a couple of months
self.reactor.advance(60 * 24 * 60 * 60)
# We should get no results.
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={"user_id": user_id},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
self.assertEqual(result, [])
@@ -696,28 +688,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(200)
# We should see that in the DB
- result = self.get_success(
- self.store.db_pool.simple_select_list(
- table="user_ips",
- keyvalues={},
- retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
- desc="get_user_ip_and_agents",
- )
+ result = cast(
+ List[Tuple[str, str, str, Optional[str], int]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="user_ips",
+ keyvalues={},
+ retcols=[
+ "access_token",
+ "ip",
+ "user_agent",
+ "device_id",
+ "last_seen",
+ ],
+ desc="get_user_ip_and_agents",
+ )
+ ),
)
# ensure user1 is filtered out
- self.assertEqual(
- result,
- [
- {
- "access_token": access_token2,
- "ip": "ip",
- "user_agent": "user_agent",
- "device_id": device_id2,
- "last_seen": 0,
- }
- ],
- )
+ self.assertEqual(result, [(access_token2, "ip", "user_agent", device_id2, 0)])
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9174fb0964..fd53b0644c 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
- # The table is empty so we expect an empty map for positions
- self.assertEqual(id_gen.get_positions(), {})
+ # The table is empty so we expect the map for positions to have a dummy
+ # minimum value.
+ self.assertEqual(id_gen.get_positions(), {"master": 1})
def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
@@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- # The first ID gen will notice that it can advance its token to 7 as it
- # has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
- # ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
@@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+ def test_multi_instance_empty_row(self) -> None:
+ """Test that reads and writes from multiple processes are handled
+ correctly, when one of the writers starts without any rows.
+ """
+ # Insert some rows for two out of three of the ID gens.
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+ third_id_gen = self._create_id_generator(
+ "third", writers=["first", "second", "third"]
+ )
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
+
+ # Try allocating a new ID gen and check that we only see position
+ # advanced after we leave the context manager.
+
+ async def _get_next_async() -> None:
+ async with third_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
+ )
+ self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(
+ third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
+ )
+
def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly."""
@@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
with self.assertRaises(IncorrectDatabaseSetup):
self._create_id_generator("first")
+ def test_minimal_local_token(self) -> None:
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
+
+ def test_current_token_gap(self) -> None:
+ """Test that getting the current token for a writer returns the maximal
+ token when there are no writes.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator(
+ "first", writers=["first", "second", "third"]
+ )
+ second_id_gen = self._create_id_generator(
+ "second", writers=["first", "second", "third"]
+ )
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ # Check that the first ID gen advancing causes the second ID gen to
+ # advance (as the second ID gen has nothing in flight).
+
+ async def _get_next_async() -> None:
+ async with first_id_gen.get_next_mult(2):
+ pass
+
+ self.get_success(_get_next_async())
+ second_id_gen.advance("first", 9)
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ # Check that the first ID gen advancing doesn't advance the second ID
+ # gen when the second ID gen has stuff in flight.
+ self.get_success(_get_next_async())
+
+ ctxmgr = second_id_gen.get_next()
+ self.get_success(ctxmgr.__aenter__())
+
+ second_id_gen.advance("first", 11)
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
+ self.get_success(ctxmgr.__aexit__(None, None, None))
+
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
+ self.assertEqual(second_id_gen.get_current_token(), 7)
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async())
- self.assertEqual(id_gen_1.get_positions(), {"first": -1})
- self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
@@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
- self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
- self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 95f99f4130..6afb5403bd 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.store.db_pool.execute(
- "", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
+ "", "SELECT full_user_id from profiles ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f4c4661aaf..36fcab06b5 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional, Tuple, cast
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import Membership
@@ -110,21 +112,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- res = self.get_success(
- self.store.db_pool.simple_select_list(
- "room_memberships",
- {"user_id": "@alice:test"},
- ["display_name", "event_id"],
- )
+ res = cast(
+ List[Tuple[Optional[str], str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ ),
)
# Check that we only got one result back
self.assertEqual(len(res), 1)
# Check that alice's display name is "alice"
- self.assertEqual(res[0]["display_name"], "alice")
+ self.assertEqual(res[0][0], "alice")
# Grab the event_id to use later
- event_id = res[0]["event_id"]
+ event_id = res[0][1]
# Create a profile with the offending null byte in the display name
new_profile = {"displayname": "ali\u0000ce"}
@@ -139,21 +144,24 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
tok=self.t_alice,
)
- res2 = self.get_success(
- self.store.db_pool.simple_select_list(
- "room_memberships",
- {"user_id": "@alice:test"},
- ["display_name", "event_id"],
- )
+ res2 = cast(
+ List[Tuple[Optional[str], str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "room_memberships",
+ {"user_id": "@alice:test"},
+ ["display_name", "event_id"],
+ )
+ ),
)
# Check that we only have two results
self.assertEqual(len(res2), 2)
# Filter out the previous event using the event_id we grabbed above
- row = [row for row in res2 if row["event_id"] != event_id]
+ row = [row for row in res2 if row[1] != event_id]
# Check that alice's display name is now None
- self.assertEqual(row[0]["display_name"], None)
+ self.assertIsNone(row[0][0])
def test_room_is_locally_forgotten(self) -> None:
"""Test that when the last local user has forgotten a room it is known as forgotten."""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b9446c36c..2715c73f16 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from typing import List, Tuple, cast
from immutabledict import immutabledict
@@ -584,18 +585,21 @@ class StateStoreTestCase(HomeserverTestCase):
)
# check that only state events are in state_groups, and all state events are in state_groups
- res = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_groups",
- keyvalues=None,
- retcols=("event_id",),
- )
+ res = cast(
+ List[Tuple[str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups",
+ keyvalues=None,
+ retcols=("event_id",),
+ )
+ ),
)
events = []
for result in res:
- self.assertNotIn(event3.event_id, result)
- events.append(result.get("event_id"))
+ self.assertNotIn(event3.event_id, result) # XXX
+ events.append(result[0])
for event, _ in processed_events_and_context:
if event.is_state():
@@ -606,23 +610,29 @@ class StateStoreTestCase(HomeserverTestCase):
# has an entry and prev event in state_group_edges
for event, context in processed_events_and_context:
if event.is_state():
- state = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_groups_state",
- keyvalues={"state_group": context.state_group_after_event},
- retcols=("type", "state_key"),
- )
- )
- self.assertEqual(event.type, state[0].get("type"))
- self.assertEqual(event.state_key, state[0].get("state_key"))
-
- groups = self.get_success(
- self.store.db_pool.simple_select_list(
- table="state_group_edges",
- keyvalues={"state_group": str(context.state_group_after_event)},
- retcols=("*",),
- )
+ state = cast(
+ List[Tuple[str, str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups_state",
+ keyvalues={"state_group": context.state_group_after_event},
+ retcols=("type", "state_key"),
+ )
+ ),
)
- self.assertEqual(
- context.state_group_before_event, groups[0].get("prev_state_group")
+ self.assertEqual(event.type, state[0][0])
+ self.assertEqual(event.state_key, state[0][1])
+
+ groups = cast(
+ List[Tuple[str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_group_edges",
+ keyvalues={
+ "state_group": str(context.state_group_after_event)
+ },
+ retcols=("prev_state_group",),
+ )
+ ),
)
+ self.assertEqual(context.state_group_before_event, groups[0][0])
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 8c72aa1722..822c41dd9f 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, cast
from unittest import mock
from unittest.mock import Mock, patch
@@ -62,14 +62,13 @@ class GetUserDirectoryTables:
Returns a list of tuples (user_id, room_id) where room_id is public and
contains the user with the given id.
"""
- r = await self.store.db_pool.simple_select_list(
- "users_in_public_rooms", None, ("user_id", "room_id")
+ r = cast(
+ List[Tuple[str, str]],
+ await self.store.db_pool.simple_select_list(
+ "users_in_public_rooms", None, ("user_id", "room_id")
+ ),
)
-
- retval = set()
- for i in r:
- retval.add((i["user_id"], i["room_id"]))
- return retval
+ return set(r)
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table.
@@ -78,27 +77,30 @@ class GetUserDirectoryTables:
to the rows of `users_who_share_private_rooms`.
"""
- rows = await self.store.db_pool.simple_select_list(
- "users_who_share_private_rooms",
- None,
- ["user_id", "other_user_id", "room_id"],
+ rows = cast(
+ List[Tuple[str, str, str]],
+ await self.store.db_pool.simple_select_list(
+ "users_who_share_private_rooms",
+ None,
+ ["user_id", "other_user_id", "room_id"],
+ ),
)
- rv = set()
- for row in rows:
- rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
- return rv
+ return set(rows)
async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table.
This is useful when checking we've correctly excluded users from the directory.
"""
- result = await self.store.db_pool.simple_select_list(
- "user_directory",
- None,
- ["user_id"],
+ result = cast(
+ List[Tuple[str]],
+ await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ["user_id"],
+ ),
)
- return {row["user_id"] for row in result}
+ return {row[0] for row in result}
async def get_profiles_in_user_directory(self) -> Dict[str, ProfileInfo]:
"""Fetch users and their profiles from the `user_directory` table.
@@ -107,16 +109,17 @@ class GetUserDirectoryTables:
It's almost the entire contents of the `user_directory` table: the only
thing missing is an unused room_id column.
"""
- rows = await self.store.db_pool.simple_select_list(
- "user_directory",
- None,
- ("user_id", "display_name", "avatar_url"),
+ rows = cast(
+ List[Tuple[str, Optional[str], Optional[str]]],
+ await self.store.db_pool.simple_select_list(
+ "user_directory",
+ None,
+ ("user_id", "display_name", "avatar_url"),
+ ),
)
return {
- row["user_id"]: ProfileInfo(
- display_name=row["display_name"], avatar_url=row["avatar_url"]
- )
- for row in rows
+ user_id: ProfileInfo(display_name=display_name, avatar_url=avatar_url)
+ for user_id, display_name, avatar_url in rows
}
async def get_tables(
diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py
index d4637d9d1e..2da6a018e8 100644
--- a/tests/storage/test_user_filters.py
+++ b/tests/storage/test_user_filters.py
@@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.store.db_pool.execute(
- "", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
+ "", "SELECT full_user_id from user_filters ORDER BY full_user_id"
)
)
self.assertEqual(len(res), len(expected_values))
diff --git a/tests/unittest.py b/tests/unittest.py
index 99ad02eb06..79c47fc3cc 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from typing import (
Generic,
Iterable,
List,
+ Mapping,
NoReturn,
Optional,
Tuple,
@@ -251,7 +252,7 @@ class TestCase(unittest.TestCase):
except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e
- def assert_dict(self, required: dict, actual: dict) -> None:
+ def assert_dict(self, required: Mapping, actual: Mapping) -> None:
"""Does a partial assert of a dict.
Args:
|