summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/tests/server.py b/tests/server.py
index 82990c2eb9..6ce2a17bf4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
     ITransport,
 )
 from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import (
+    AccumulatingProtocol,
+    MemoryReactor,
+    MemoryReactorClock,
+)
 from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.http.site import SynapseRequest
+from synapse.logging.context import ContextResourceUsage
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
@@ -88,18 +93,19 @@ class TimedOutException(Exception):
     """
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class FakeChannel:
     """
     A fake Twisted Web Channel (the part that interfaces with the
     wire).
     """
 
-    site = attr.ib(type=Union[Site, "FakeSite"])
-    _reactor = attr.ib()
-    result = attr.ib(type=dict, default=attr.Factory(dict))
-    _ip = attr.ib(type=str, default="127.0.0.1")
+    site: Union[Site, "FakeSite"]
+    _reactor: MemoryReactor
+    result: dict = attr.Factory(dict)
+    _ip: str = "127.0.0.1"
     _producer: Optional[Union[IPullProducer, IPushProducer]] = None
+    resource_usage: Optional[ContextResourceUsage] = None
 
     @property
     def json_body(self):
@@ -168,6 +174,8 @@ class FakeChannel:
 
     def requestDone(self, _self):
         self.result["done"] = True
+        if isinstance(_self, SynapseRequest):
+            self.resource_usage = _self.logcontext.get_resource_usage()
 
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,