diff --git a/tests/test_types.py b/tests/test_types.py
index 944aa784fc..00adc65a5a 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -19,9 +19,18 @@
#
#
+from typing import Type
+from unittest import skipUnless
+
+from immutabledict import immutabledict
+from parameterized import parameterized_class
+
from synapse.api.errors import SynapseError
from synapse.types import (
+ AbstractMultiWriterStreamToken,
+ MultiWriterStreamToken,
RoomAlias,
+ RoomStreamToken,
UserID,
get_domain_from_id,
get_localpart_from_id,
@@ -29,6 +38,7 @@ from synapse.types import (
)
from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
+
+
+@parameterized_class(
+ ("token_type",),
+ [
+ (MultiWriterStreamToken,),
+ (RoomStreamToken,),
+ ],
+ class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
+)
+class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
+ """Tests for the different types of multi writer tokens."""
+
+ token_type: Type[AbstractMultiWriterStreamToken]
+
+ def test_basic_token(self) -> None:
+ """Test that a simple stream token can be serialized and unserialized"""
+ store = self.hs.get_datastores().main
+
+ token = self.token_type(stream=5)
+
+ string_token = self.get_success(token.to_string(store))
+
+ if isinstance(token, RoomStreamToken):
+ self.assertEqual(string_token, "s5")
+ else:
+ self.assertEqual(string_token, "5")
+
+ parsed_token = self.get_success(self.token_type.parse(store, string_token))
+ self.assertEqual(parsed_token, token)
+
+ @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
+ def test_instance_map(self) -> None:
+ """Test for stream token with instance map"""
+ store = self.hs.get_datastores().main
+
+ token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
+
+ string_token = self.get_success(token.to_string(store))
+ self.assertEqual(string_token, "m5~1.6")
+
+ parsed_token = self.get_success(self.token_type.parse(store, string_token))
+ self.assertEqual(parsed_token, token)
+
+ def test_instance_map_assertion(self) -> None:
+ """Test that we assert values in the instance map are greater than the
+ min stream position"""
+
+ with self.assertRaises(ValueError):
+ self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
+
+ with self.assertRaises(ValueError):
+ self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
+
+ def test_parse_bad_token(self) -> None:
+ """Test that we can parse tokens produced by a bug in Synapse of the
+ form `m5~`"""
+ store = self.hs.get_datastores().main
+
+ parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
+ self.assertEqual(parsed_token, self.token_type(stream=5))