summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_server.py61
1 files changed, 60 insertions, 1 deletions
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index a10d017120..98af7aa675 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -15,7 +15,8 @@
 
 import yaml
 
-from synapse.config.server import ServerConfig, is_threepid_reserved
+from synapse.config._base import ConfigError
+from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
 
 from tests import unittest
 
@@ -128,3 +129,61 @@ class ServerConfigTestCase(unittest.TestCase):
         )
 
         self.assertEqual(conf["listeners"], expected_listeners)
+
+
+class GenerateIpSetTestCase(unittest.TestCase):
+    def test_empty(self):
+        ip_set = generate_ip_set(())
+        self.assertFalse(ip_set)
+
+        ip_set = generate_ip_set((), ())
+        self.assertFalse(ip_set)
+
+    def test_generate(self):
+        """Check adding IPv4 and IPv6 addresses."""
+        # IPv4 address
+        ip_set = generate_ip_set(("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        # IPv4 CIDR
+        ip_set = generate_ip_set(("1.2.3.4/24",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        # IPv6 address
+        ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+        # IPv6 CIDR
+        ip_set = generate_ip_set(("2001:db8::/104",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+        # The addresses can overlap OK.
+        ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+    def test_extra(self):
+        """Extra IP addresses are treated the same."""
+        ip_set = generate_ip_set((), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 8)
+
+        # They can duplicate without error.
+        ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+    def test_bad_value(self):
+        """An error should be raised if a bad value is passed in."""
+        with self.assertRaises(ConfigError):
+            generate_ip_set(("not-an-ip",))
+
+        with self.assertRaises(ConfigError):
+            generate_ip_set(("1.2.3.4/128",))
+
+        with self.assertRaises(ConfigError):
+            generate_ip_set((":::",))
+
+        # The following get treated as empty data.
+        self.assertFalse(generate_ip_set(None))
+        self.assertFalse(generate_ip_set({}))