diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 69945a8f98..eb78ab412a 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# another lookup.
self.reactor.pump((900.0,))
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
# The resolver may retry a few times, so fonx all requests that come along
attempts = 0
@@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((10000.0,))
# Repated the request, this time it should fail if the lookup fails.
- fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+ fetch_d = defer.ensureDeferred(
+ self.well_known_resolver.get_well_known(b"testserv")
+ )
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0b5204654c..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ event.event_id,
+ {user_id: actions for user_id, actions in push_actions},
+ False,
)
)
return event, context
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
new file mode 100644
index 0000000000..5ae72fd008
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import shared_rooms
+
+from tests import unittest
+
+
+class UserSharedRoomsTest(unittest.HomeserverTestCase):
+ """
+ Tests the UserSharedRoomsServlet.
+ """
+
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ shared_rooms.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["update_user_directory"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def _get_shared_rooms(self, token, other_user):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ % other_user,
+ access_token=token,
+ )
+ self.render(request)
+ return request, channel
+
+ def test_shared_room_list_public(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is public.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_private(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is private.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_mixed(self):
+ """
+ The shared room list between two users should contain both public and private
+ rooms.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
+ self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
+ self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
+ self.helper.join(room_public, user=u2, tok=u2_token)
+ self.helper.join(room_private, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 2)
+ self.assertTrue(room_public in channel.json_body["joined"])
+ self.assertTrue(room_private in channel.json_body["joined"])
+
+ def test_shared_room_list_after_leave(self):
+ """
+ A room should no longer be considered shared if the other
+ user has left it.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Assert user directory is not empty
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ self.helper.leave(room, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u2_token, u1)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ read_marker.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.url = "/sync?since=%s"
+ self.next_batch = "s0"
+
+ # Register the first user (used to check the unread counts).
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ # Create the room we'll check unread counts for.
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ # Register the second user (used to send events to the room).
+ self.user2 = self.register_user("kermit2", "monkey")
+ self.tok2 = self.login("kermit2", "monkey")
+
+ # Change the power levels of the room so that the second user can send state
+ # events.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.PowerLevels,
+ {
+ "users": {self.user_id: 100, self.user2: 100},
+ "users_default": 0,
+ "events": {
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.history_visibility": 100,
+ "m.room.canonical_alias": 50,
+ "m.room.avatar": 50,
+ "m.room.tombstone": 100,
+ "m.room.server_acl": 100,
+ "m.room.encryption": 100,
+ },
+ "events_default": 0,
+ "state_default": 50,
+ "ban": 50,
+ "kick": 50,
+ "redact": 50,
+ "invite": 0,
+ },
+ tok=self.tok,
+ )
+
+ def test_unread_counts(self):
+ """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+ # Check that our own messages don't increase the unread count.
+ self.helper.send(self.room_id, "hello", tok=self.tok)
+ self._check_unread_count(0)
+
+ # Join the new user and check that this doesn't increase the unread count.
+ self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+ self._check_unread_count(0)
+
+ # Check that the new user sending a message increases our unread count.
+ res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+ self._check_unread_count(1)
+
+ # Send a read receipt to tell the server we've read the latest event.
+ body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/read_markers" % self.room_id,
+ body,
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that the unread counter is back to 0.
+ self._check_unread_count(0)
+
+ # Check that room name changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+ )
+ self._check_unread_count(1)
+
+ # Check that room topic changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+ )
+ self._check_unread_count(2)
+
+ # Check that encrypted messages increase the unread counter.
+ self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+ self._check_unread_count(3)
+
+ # Check that custom events with a body increase the unread counter.
+ self.helper.send_event(
+ self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that edits don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": "hello",
+ "msgtype": "m.text",
+ "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+ },
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that notices don't increase the unread counter.
+ self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"body": "hello", "msgtype": "m.notice"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(4)
+
+ # Check that tombstone events changes increase the unread counter.
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.Tombstone,
+ {"replacement_room": "!someroom:test"},
+ tok=self.tok2,
+ )
+ self._check_unread_count(5)
+
+ def _check_unread_count(self, expected_count: True):
+ """Syncs and compares the unread count with the expected value."""
+
+ request, channel = self.make_request(
+ "GET", self.url % self.next_batch, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ room_entry = channel.json_body["rooms"]["join"][self.room_id]
+ self.assertEqual(
+ room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+ )
+
+ # Store the next batch for the next request.
+ self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 261bf5b08b..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user", "device"),))
+ self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
- self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
+ self.store.get_e2e_device_keys_for_cs_api(
+ (("user1", "device1"), ("user2", "device2"))
+ )
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index cdfd2634aa..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
self.assertEquals(
counts,
- {"notify_count": noitf_count, "highlight_count": highlight_count},
+ {
+ "notify_count": noitf_count,
+ "unread_count": 0, # Unread counts are tested in the sync tests.
+ "highlight_count": highlight_count,
+ },
)
@defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
- event.event_id, {user_id: action}
+ event.event_id, {user_id: action}, False,
)
)
yield defer.ensureDeferred(
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 14ce21c786..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+ """
+
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+ id_gen = self._create_id_generator()
+
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -1})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -4})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+ # Test loading from DB by creating a second ID gen
+ second_id_gen = self._create_id_generator()
+
+ self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+ def test_multiple_instance(self):
+ """Tests that having multiple instances that get advanced over
+ federation works corretly.
+ """
+ id_gen_1 = self._create_id_generator("first")
+ id_gen_2 = self._create_id_generator("second")
+
+ with self.get_success(id_gen_1.get_next()) as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen_2.get_next()) as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8522c6fc09..fb1ca90336 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,14 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Collection
"""
Utility functions for poking events into the storage of the server under test.
@@ -58,7 +57,7 @@ async def inject_member_event(
async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
@@ -80,7 +79,7 @@ async def inject_event(
async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
if room_version is None:
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index bd32e2cee7..d3dea3b52a 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.async_helpers import ReadWriteLock
@@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
rwlock.read(key), # 5
rwlock.write(key), # 6
]
+ ds = [defer.ensureDeferred(d) for d in ds]
self._assert_called_before_not_after(ds, 2)
@@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
with ds[6].result:
pass
- d = rwlock.write(key)
+ d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called)
with d.result:
pass
- d = rwlock.read(key)
+ d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called)
with d.result:
pass
|