diff options
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r-- | tests/replication/_base.py | 54 |
1 files changed, 42 insertions, 12 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a7602b4c96..970d5e533b 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol @@ -32,6 +33,7 @@ from synapse.server import HomeServer from tests import unittest from tests.server import FakeTransport +from tests.utils import USE_POSTGRES_FOR_TESTS try: import hiredis @@ -475,22 +477,25 @@ class FakeRedisPubSubServer: """A fake Redis server for pub/sub.""" def __init__(self): - self._subscribers = set() + self._subscribers_by_channel: Dict[ + bytes, Set["FakeRedisPubSubProtocol"] + ] = defaultdict(set) - def add_subscriber(self, conn): + def add_subscriber(self, conn, channel: bytes): """A connection has called SUBSCRIBE""" - self._subscribers.add(conn) + self._subscribers_by_channel[channel].add(conn) def remove_subscriber(self, conn): - """A connection has called UNSUBSCRIBE""" - self._subscribers.discard(conn) + """A connection has lost connection""" + for subscribers in self._subscribers_by_channel.values(): + subscribers.discard(conn) - def publish(self, conn, channel, msg) -> int: + def publish(self, conn, channel: bytes, msg) -> int: """A connection want to publish a message to subscribers.""" - for sub in self._subscribers: + for sub in self._subscribers_by_channel[channel]: sub.send(["message", channel, msg]) - return len(self._subscribers) + return len(self._subscribers_by_channel) def buildProtocol(self, addr): return FakeRedisPubSubProtocol(self) @@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol): num_subscribers = self._server.publish(self, channel, message) self.send(num_subscribers) elif command == b"SUBSCRIBE": - (channel,) = args - self._server.add_subscriber(self) - self.send(["subscribe", channel, 1]) + for idx, channel in enumerate(args): + num_channels = idx + 1 + self._server.add_subscriber(self, channel) + self.send(["subscribe", channel, num_channels]) # Since we use SET/GET to cache things we can safely no-op them. elif command == b"SET": @@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol): def connectionLost(self, reason): self._server.remove_subscriber(self) + + +class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase): + """ + A test case that enables Redis, providing a fake Redis server. + """ + + if not hiredis: + skip = "Requires hiredis" + + if not USE_POSTGRES_FOR_TESTS: + # Redis replication only takes place on Postgres + skip = "Requires Postgres" + + def default_config(self) -> Dict[str, Any]: + """ + Overrides the default config to enable Redis. + Even if the test only uses make_worker_hs, the main process needs Redis + enabled otherwise it won't create a Fake Redis server to listen on the + Redis port and accept fake TCP connections. + """ + base = super().default_config() + base["redis"] = {"enabled": True} + return base |