summary refs log tree commit diff
path: root/tests/test_utils
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils')
-rw-r--r--tests/test_utils/__init__.py39
-rw-r--r--tests/test_utils/html_parsers.py53
-rw-r--r--tests/test_utils/logging_setup.py2
3 files changed, 93 insertions, 1 deletions
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,13 @@ import warnings
 from asyncio import Future
 from typing import Any, Awaitable, Callable, TypeVar
 
+from mock import Mock
+
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
 TV = TypeVar("TV")
 
 
@@ -80,3 +87,35 @@ def setup_awaitable_errors() -> Callable[[], None]:
     sys.unraisablehook = unraisablehook  # type: ignore
 
     return cleanup
+
+
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+    # AsyncMock is not available in python3.5, this mimics part of its behaviour
+    async def cb(*args, **kwargs):
+        if raises:
+            raise raises
+        return return_value
+
+    return Mock(side_effect=cb)
+
+
+@attr.s
+class FakeResponse:
+    """A fake twisted.web.IResponse object
+
+    there is a similar class at treq.test.test_response, but it lacks a `phrase`
+    attribute, and didn't support deliverBody until recently.
+    """
+
+    # HTTP response code
+    code = attr.ib(type=int)
+
+    # HTTP response phrase (eg b'OK' for a 200)
+    phrase = attr.ib(type=bytes)
+
+    # body of the response
+    body = attr.ib(type=bytes)
+
+    def deliverBody(self, protocol):
+        protocol.dataReceived(self.body)
+        protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+    """A generic HTML page parser which extracts useful things from the HTML"""
+
+    def __init__(self):
+        super().__init__()
+
+        # a list of links found in the doc
+        self.links = []  # type: List[str]
+
+        # the values of any hidden <input>s: map from name to value
+        self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+        # the values of any radio buttons: map from name to list of values
+        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+
+    def handle_starttag(
+        self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+    ) -> None:
+        attr_dict = dict(attrs)
+        if tag == "a":
+            href = attr_dict["href"]
+            if href:
+                self.links.append(href)
+        elif tag == "input":
+            input_name = attr_dict.get("name")
+            if attr_dict["type"] == "radio":
+                assert input_name
+                self.radios.setdefault(input_name, []).append(attr_dict["value"])
+            elif attr_dict["type"] == "hidden":
+                assert input_name
+                self.hiddens[input_name] = attr_dict["value"]
+
+    def error(_, message):
+        raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
     handler = ToTwistedHandler()
     formatter = logging.Formatter(log_format)
     handler.setFormatter(formatter)
-    handler.addFilter(LoggingContextFilter(request=""))
+    handler.addFilter(LoggingContextFilter())
     root_logger.addHandler(handler)
 
     log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")