diff --git a/changelog.d/8587.misc b/changelog.d/8587.misc
new file mode 100644
index 0000000000..9e56551a34
--- /dev/null
+++ b/changelog.d/8587.misc
@@ -0,0 +1 @@
+Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.
diff --git a/changelog.d/8600.misc b/changelog.d/8600.misc
new file mode 100644
index 0000000000..a5a922e641
--- /dev/null
+++ b/changelog.d/8600.misc
@@ -0,0 +1 @@
+Update `mypy` static type checker to 0.790.
diff --git a/changelog.d/8606.feature b/changelog.d/8606.feature
new file mode 100644
index 0000000000..fad723c108
--- /dev/null
+++ b/changelog.d/8606.feature
@@ -0,0 +1 @@
+Limit appservice transactions to 100 persistent and 100 ephemeral events.
diff --git a/changelog.d/8609.misc b/changelog.d/8609.misc
new file mode 100644
index 0000000000..5e3f3c1993
--- /dev/null
+++ b/changelog.d/8609.misc
@@ -0,0 +1 @@
+Add type hints to profile and base handler.
diff --git a/mypy.ini b/mypy.ini
index b5db54ee3b..5e9f7b1259 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -15,8 +15,9 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
- synapse/handlers/appservice.py,
+ synapse/handlers/_base.py,
synapse/handlers/account_data.py,
+ synapse/handlers/appservice.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py,
@@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py,
synapse/handlers/password_policy.py,
synapse/handlers/presence.py,
+ synapse/handlers/profile.py,
synapse/handlers/read_marker.py,
synapse/handlers/room.py,
synapse/handlers/room_member.py,
diff --git a/setup.py b/setup.py
index 494f50239f..2f4a3170d2 100755
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.782", "mypy-zope"]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index ad3c408519..58291afc22 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -60,6 +60,13 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+# Maximum number of persistent events to provide in an AS transaction.
+MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
+
+# Maximum number of ephemeral events to provide in an AS transaction.
+MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
+
+
class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
@@ -136,10 +143,17 @@ class _ServiceQueuer:
self.requests_in_flight.add(service.id)
try:
while True:
- events = self.queued_events.pop(service.id, [])
- ephemeral = self.queued_ephemeral.pop(service.id, [])
+ all_events = self.queued_events.get(service.id, [])
+ events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+ del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
+
+ all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
+ ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+ del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
+
if not events and not ephemeral:
return
+
try:
await self.txn_ctrl.send(service, events, ephemeral)
except Exception:
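A note on the dequeue change above: `pop` used to drain a service's entire queue into one transaction, whereas the new slice-then-`del` takes at most 100 events and leaves the remainder queued for the next loop iteration. A minimal sketch of the same take-at-most-N pattern on a plain dict of lists (names here are illustrative, not Synapse's):

```python
from collections import defaultdict
from typing import Dict, List

MAX_PER_BATCH = 100  # mirrors MAX_PERSISTENT_EVENTS_PER_TRANSACTION

queues = defaultdict(list)  # type: Dict[str, List[str]]


def take_batch(service_id: str) -> List[str]:
    """Remove and return at most MAX_PER_BATCH queued items for a service.

    Slicing and then deleting the same range mutates the queued list in
    place, so anything beyond the cap stays queued for the next call,
    unlike a dict.pop() which hands over the whole queue at once.
    """
    queue = queues.get(service_id, [])
    batch = queue[:MAX_PER_BATCH]
    del queue[:MAX_PER_BATCH]
    return batch


queues["as1"] = ["event%i" % i for i in range(250)]
batch_sizes = []
while True:
    batch = take_batch("as1")
    if not batch:
        break
    batch_sizes.append(len(batch))
print(batch_sizes)  # [100, 100, 50]
```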
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..bd8e71ae56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
import synapse.state
import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -56,7 +56,7 @@ class BaseHandler:
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
- )
+ ) # type: Optional[Ratelimiter]
else:
self.admin_redaction_ratelimiter = None
@@ -127,15 +127,15 @@ class BaseHandler:
if guest_access != "can_join":
if context:
current_state_ids = await context.get_current_state_ids()
- current_state = await self.store.get_events(
+ current_state_dict = await self.store.get_events(
list(current_state_ids.values())
)
+ current_state = list(current_state_dict.values())
else:
- current_state = await self.state_handler.get_current_state(
+ current_state_map = await self.state_handler.get_current_state(
event.room_id
)
-
- current_state = list(current_state.values())
+ current_state = list(current_state_map.values())
logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state)
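The renames in `maybe_kick_guest_users` (and the `# type: Optional[Ratelimiter]` comment earlier in the hunk) are there for mypy rather than for runtime behaviour: mypy infers one type per variable from its first assignment, so reusing `current_state` for both the dict returned by the store and the list derived from it is rejected as an incompatible assignment. A toy illustration of both patterns (not Synapse code):

```python
from typing import Dict, List, Optional


def flatten_state(state_map: Dict[str, str]) -> List[str]:
    # Under mypy's default rules, a variable takes its type from the first
    # assignment, so
    #     current_state = state_map                      # Dict[str, str]
    #     current_state = list(current_state.values())   # rejected
    # is an "incompatible types in assignment" error. Distinct names keep
    # each variable mono-typed, which is all the rename above does.
    current_state_list = list(state_map.values())
    return current_state_list


class Ratelimiter:
    pass


enabled = False
# The `# type: Optional[Ratelimiter]` comment works the same way: annotate
# the first assignment as Optional so the None branch unifies with it.
limiter = Ratelimiter() if enabled else None  # type: Optional[Ratelimiter]
```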
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 98075f48d2..cb11754bf8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
+ # The member_event_id will always be available if membership is set
+ # to leave.
+ assert member_event_id
+
result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
@@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
- membership: Membership,
+ membership: str,
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
@@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str,
room_id: str,
pagin_config: PaginationConfig,
- membership: Membership,
+ membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
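The added `assert member_event_id` is doing type work as much as runtime validation: `member_event_id` is an Optional at that point, while `_room_initial_sync_parted` now takes a plain `str`, and an assert is the conventional way to narrow the Optional for mypy. A self-contained sketch of the idiom, with hypothetical stand-in functions:

```python
from typing import Optional


def lookup_member_event(membership: str) -> Optional[str]:
    # Stand-in for the store lookup, which can return None.
    return "$event:example.org" if membership == "leave" else None


def room_initial_sync_parted(member_event_id: str) -> str:
    # Requires a plain str, like the typed signature above.
    return "parted at " + member_event_id


def initial_sync(membership: str) -> Optional[str]:
    member_event_id = lookup_member_event(membership)
    if membership == "leave":
        # mypy narrows Optional[str] to str after this assert; at runtime
        # it also fails fast if the invariant is ever wrong.
        assert member_event_id
        return room_initial_sync_parted(member_event_id)
    return None


print(initial_sync("leave"))  # parted at $event:example.org
```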
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index b78a12ad01..92700b589c 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
+from typing import TYPE_CHECKING, Optional
from synapse.api.errors import (
AuthError,
@@ -25,10 +25,19 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.types import UserID, create_requester, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ Requester,
+ UserID,
+ create_requester,
+ get_domain_from_id,
+)
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
@@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.federation = hs.get_federation_client()
@@ -60,7 +69,7 @@ class ProfileHandler(BaseHandler):
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
- async def get_profile(self, user_id):
+ async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
@@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- async def get_profile_from_cache(self, user_id):
+ async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is
         ours then the profile information will always be correct. Otherwise,
it may be out of date/missing.
@@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
- async def get_displayname(self, target_user):
+ async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
@@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
return result["displayname"]
async def set_displayname(
- self, target_user, requester, new_displayname, by_admin=False
- ):
+ self,
+ target_user: UserID,
+ requester: Requester,
+ new_displayname: str,
+ by_admin: bool = False,
+ ) -> None:
"""Set the displayname of a user
Args:
- target_user (UserID): the user whose displayname is to be changed.
- requester (Requester): The user attempting to make this change.
- new_displayname (str): The displayname to give this user.
- by_admin (bool): Whether this change was made by an administrator.
+ target_user: the user whose displayname is to be changed.
+ requester: The user attempting to make this change.
+ new_displayname: The displayname to give this user.
+ by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
+ displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "":
- new_displayname = None
+ displayname_to_set = None
# If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms.
@@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
- await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(
+ target_user.localpart, displayname_to_set
+ )
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
@@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
- async def get_avatar_url(self, target_user):
+ async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user):
try:
avatar_url = await self.store.get_profile_avatar_url(
@@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
return result["avatar_url"]
async def set_avatar_url(
- self, target_user, requester, new_avatar_url, by_admin=False
+ self,
+ target_user: UserID,
+ requester: Requester,
+ new_avatar_url: str,
+ by_admin: bool = False,
):
"""Set a new avatar URL for a user.
Args:
- target_user (UserID): the user whose avatar URL is to be changed.
- requester (Requester): The user attempting to make this change.
- new_avatar_url (str): The avatar URL to give this user.
- by_admin (bool): Whether this change was made by an administrator.
+ target_user: the user whose avatar URL is to be changed.
+ requester: The user attempting to make this change.
+ new_avatar_url: The avatar URL to give this user.
+ by_admin: Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
- async def on_profile_query(self, args):
+ async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
return response
- async def _update_join_states(self, requester, target_user):
+ async def _update_join_states(
+ self, requester: Requester, target_user: UserID
+ ) -> None:
if not self.hs.is_mine(target_user):
return
@@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e)
)
- async def check_profile_query_allowed(self, target_user, requester=None):
+ async def check_profile_query_allowed(
+ self, target_user: UserID, requester: Optional[UserID] = None
+ ) -> None:
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
Args:
- target_user (UserID): The owner of the queried profile.
- requester (None|UserID): The user querying for the profile.
+ target_user: The owner of the queried profile.
+ requester: The user querying for the profile.
Raises:
             SynapseError(403): The two users share no room, or one user couldn't
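`displayname_to_set` exists for the same mypy reason as the renames in `_base.py`: the parameter is annotated `str`, so assigning `None` to it would be an incompatible assignment; routing the possibly-`None` value through a separate `Optional[str]` variable keeps both names accurately typed, and matches `set_profile_displayname`'s new `Optional[str]` parameter in the storage diff further down. A condensed sketch (illustrative function name):

```python
from typing import Optional


def normalise_displayname(new_displayname: str) -> Optional[str]:
    # The parameter stays a str. Mapping the empty string to None happens
    # in a separate Optional[str] variable instead of by reassigning the
    # parameter, which mypy would reject as an incompatible assignment.
    displayname_to_set = new_displayname  # type: Optional[str]
    if new_displayname == "":
        displayname_to_set = None
    return displayname_to_set


assert normalise_displayname("Alice") == "Alice"
assert normalise_displayname("") is None
```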
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
new file mode 100644
index 0000000000..0caf325916
--- /dev/null
+++ b/synapse/logging/_remote.py
@@ -0,0 +1,225 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import traceback
+from collections import deque
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from math import floor
+from typing import Callable, Optional
+
+import attr
+from zope.interface import implementer
+
+from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
+from twisted.internet.endpoints import (
+ HostnameEndpoint,
+ TCP4ClientEndpoint,
+ TCP6ClientEndpoint,
+)
+from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.protocol import Factory, Protocol
+from twisted.logger import ILogObserver, Logger, LogLevel
+
+
+@attr.s
+@implementer(IPushProducer)
+class LogProducer:
+ """
+ An IPushProducer that writes logs from its buffer to its transport when it
+ is resumed.
+
+ Args:
+ buffer: Log buffer to read logs from.
+ transport: Transport to write to.
+ format_event: A callable to format the log entry to a string.
+ """
+
+ transport = attr.ib(type=ITransport)
+ format_event = attr.ib(type=Callable[[dict], str])
+ _buffer = attr.ib(type=deque)
+ _paused = attr.ib(default=False, type=bool, init=False)
+
+ def pauseProducing(self):
+ self._paused = True
+
+ def stopProducing(self):
+ self._paused = True
+ self._buffer = deque()
+
+ def resumeProducing(self):
+ self._paused = False
+
+ while self._paused is False and (self._buffer and self.transport.connected):
+ try:
+ # Request the next event and format it.
+ event = self._buffer.popleft()
+ msg = self.format_event(event)
+
+ # Send it as a new line over the transport.
+ self.transport.write(msg.encode("utf8"))
+ except Exception:
+ # Something has gone wrong writing to the transport -- log it
+ # and break out of the while.
+ traceback.print_exc(file=sys.__stderr__)
+ break
+
+
+@attr.s
+@implementer(ILogObserver)
+class TCPLogObserver:
+ """
+    An IObserver that writes formatted log events to a TCP target.
+
+ Args:
+ hs (HomeServer): The homeserver that is being logged for.
+ host: The host of the logging target.
+ port: The logging target's port.
+ format_event: A callable to format the log entry to a string.
+ maximum_buffer: The maximum buffer size.
+ """
+
+ hs = attr.ib()
+ host = attr.ib(type=str)
+ port = attr.ib(type=int)
+ format_event = attr.ib(type=Callable[[dict], str])
+ maximum_buffer = attr.ib(type=int)
+ _buffer = attr.ib(default=attr.Factory(deque), type=deque)
+ _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
+ _logger = attr.ib(default=attr.Factory(Logger))
+ _producer = attr.ib(default=None, type=Optional[LogProducer])
+
+ def start(self) -> None:
+
+ # Connect without DNS lookups if it's a direct IP.
+ try:
+ ip = ip_address(self.host)
+ if isinstance(ip, IPv4Address):
+ endpoint = TCP4ClientEndpoint(
+ self.hs.get_reactor(), self.host, self.port
+ )
+ elif isinstance(ip, IPv6Address):
+ endpoint = TCP6ClientEndpoint(
+ self.hs.get_reactor(), self.host, self.port
+ )
+ else:
+ raise ValueError("Unknown IP address provided: %s" % (self.host,))
+ except ValueError:
+ endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+
+ factory = Factory.forProtocol(Protocol)
+ self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+ self._service.startService()
+ self._connect()
+
+ def stop(self):
+ self._service.stopService()
+
+ def _connect(self) -> None:
+ """
+        Triggers an attempt to connect and then write to the remote, unless a
+        connection attempt is already pending.
+ """
+ if self._connection_waiter:
+ return
+
+ self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+ @self._connection_waiter.addErrback
+ def fail(r):
+ r.printTraceback(file=sys.__stderr__)
+ self._connection_waiter = None
+ self._connect()
+
+ @self._connection_waiter.addCallback
+ def writer(r):
+ # We have a connection. If we already have a producer, and its
+ # transport is the same, just trigger a resumeProducing.
+ if self._producer and r.transport is self._producer.transport:
+ self._producer.resumeProducing()
+ self._connection_waiter = None
+ return
+
+ # If the producer is still producing, stop it.
+ if self._producer:
+ self._producer.stopProducing()
+
+ # Make a new producer and start it.
+ self._producer = LogProducer(
+ buffer=self._buffer,
+ transport=r.transport,
+ format_event=self.format_event,
+ )
+ r.transport.registerProducer(self._producer, True)
+ self._producer.resumeProducing()
+ self._connection_waiter = None
+
+ def _handle_pressure(self) -> None:
+ """
+ Handle backpressure by shedding events.
+
+        Events are shed from the buffer, in this order, until it is below the
+        maximum size:
+ - Shed DEBUG events
+ - Shed INFO events
+ - Shed the middle 50% of the events.
+ """
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Strip out DEBUGs
+ self._buffer = deque(
+ filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
+ )
+
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Strip out INFOs
+ self._buffer = deque(
+ filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
+ )
+
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Cut the middle entries out
+ buffer_split = floor(self.maximum_buffer / 2)
+
+ old_buffer = self._buffer
+ self._buffer = deque()
+
+ for i in range(buffer_split):
+ self._buffer.append(old_buffer.popleft())
+
+ end_buffer = []
+ for i in range(buffer_split):
+ end_buffer.append(old_buffer.pop())
+
+ self._buffer.extend(reversed(end_buffer))
+
+ def __call__(self, event: dict) -> None:
+ self._buffer.append(event)
+
+ # Handle backpressure, if it exists.
+ try:
+ self._handle_pressure()
+ except Exception:
+            # If handling backpressure fails, clear the buffer and log the
+ # exception.
+ self._buffer.clear()
+            self._logger.failure("Failed to handle backpressure")
+
+ # Try and write immediately.
+ self._connect()
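The most intricate part of the extracted file is `_handle_pressure`: on overflow it sheds DEBUG events first, then INFO events, and only then cuts the middle 50% of whatever remains, keeping the start and end of the burst. A standalone sketch of that shedding policy on a plain deque of `(level, message)` pairs; this is a simplification for clarity, not the class above:

```python
from collections import deque
from typing import Deque, Tuple

Event = Tuple[str, str]  # (log level, message)


def shed(buffer: Deque[Event], maximum: int) -> Deque[Event]:
    """Drop events until the buffer fits, least useful casualties first."""
    if len(buffer) <= maximum:
        return buffer

    # Stage 1: drop DEBUG events.
    buffer = deque(e for e in buffer if e[0] != "DEBUG")
    if len(buffer) <= maximum:
        return buffer

    # Stage 2: drop INFO events.
    buffer = deque(e for e in buffer if e[0] != "INFO")
    if len(buffer) <= maximum:
        return buffer

    # Stage 3: keep the first and last maximum // 2 events and cut the
    # middle, on the theory that the edges of a burst are most diagnostic.
    half = maximum // 2
    kept = deque()  # type: Deque[Event]
    for _ in range(half):
        kept.append(buffer.popleft())
    tail = [buffer.pop() for _ in range(half)]
    kept.extend(reversed(tail))
    return kept


events = deque(
    [("DEBUG", "d%i" % i) for i in range(50)]
    + [("WARN", "w%i" % i) for i in range(20)]
)
print(len(shed(events, 10)))  # 10
```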
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 1b8916cfa2..9b46956ca9 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -18,26 +18,11 @@ Log formatters that output terse JSON.
"""
import json
-import sys
-import traceback
-from collections import deque
-from ipaddress import IPv4Address, IPv6Address, ip_address
-from math import floor
-from typing import IO, Optional
+from typing import IO
-import attr
-from zope.interface import implementer
+from twisted.logger import FileLogObserver
-from twisted.application.internet import ClientService
-from twisted.internet.defer import Deferred
-from twisted.internet.endpoints import (
- HostnameEndpoint,
- TCP4ClientEndpoint,
- TCP6ClientEndpoint,
-)
-from twisted.internet.interfaces import IPushProducer, ITransport
-from twisted.internet.protocol import Factory, Protocol
-from twisted.logger import FileLogObserver, ILogObserver, Logger
+from synapse.logging._remote import TCPLogObserver
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
@@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
return FileLogObserver(outFile, formatEvent)
-@attr.s
-@implementer(IPushProducer)
-class LogProducer:
+def TerseJSONToTCPLogObserver(
+ hs, host: str, port: int, metadata: dict, maximum_buffer: int
+) -> TCPLogObserver:
"""
- An IPushProducer that writes logs from its buffer to its transport when it
- is resumed.
-
- Args:
- buffer: Log buffer to read logs from.
- transport: Transport to write to.
- """
-
- transport = attr.ib(type=ITransport)
- _buffer = attr.ib(type=deque)
- _paused = attr.ib(default=False, type=bool, init=False)
-
- def pauseProducing(self):
- self._paused = True
-
- def stopProducing(self):
- self._paused = True
- self._buffer = deque()
-
- def resumeProducing(self):
- self._paused = False
-
- while self._paused is False and (self._buffer and self.transport.connected):
- try:
- event = self._buffer.popleft()
- self.transport.write(_encoder.encode(event).encode("utf8"))
- self.transport.write(b"\n")
- except Exception:
- # Something has gone wrong writing to the transport -- log it
- # and break out of the while.
- traceback.print_exc(file=sys.__stderr__)
- break
-
-
-@attr.s
-@implementer(ILogObserver)
-class TerseJSONToTCPLogObserver:
- """
- An IObserver that writes JSON logs to a TCP target.
+    A log observer that formats events as flattened JSON and writes them to a
+    TCP target.
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
- metadata: Metadata to be added to each log entry.
+ metadata: Metadata to be added to each log object.
+ maximum_buffer: The maximum buffer size.
"""
- hs = attr.ib()
- host = attr.ib(type=str)
- port = attr.ib(type=int)
- metadata = attr.ib(type=dict)
- maximum_buffer = attr.ib(type=int)
- _buffer = attr.ib(default=attr.Factory(deque), type=deque)
- _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
- _logger = attr.ib(default=attr.Factory(Logger))
- _producer = attr.ib(default=None, type=Optional[LogProducer])
-
- def start(self) -> None:
-
- # Connect without DNS lookups if it's a direct IP.
- try:
- ip = ip_address(self.host)
- if isinstance(ip, IPv4Address):
- endpoint = TCP4ClientEndpoint(
- self.hs.get_reactor(), self.host, self.port
- )
- elif isinstance(ip, IPv6Address):
- endpoint = TCP6ClientEndpoint(
- self.hs.get_reactor(), self.host, self.port
- )
- except ValueError:
- endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
-
- factory = Factory.forProtocol(Protocol)
- self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
- self._service.startService()
- self._connect()
-
- def stop(self):
- self._service.stopService()
-
- def _connect(self) -> None:
- """
- Triggers an attempt to connect then write to the remote if not already writing.
- """
- if self._connection_waiter:
- return
-
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
- @self._connection_waiter.addErrback
- def fail(r):
- r.printTraceback(file=sys.__stderr__)
- self._connection_waiter = None
- self._connect()
-
- @self._connection_waiter.addCallback
- def writer(r):
- # We have a connection. If we already have a producer, and its
- # transport is the same, just trigger a resumeProducing.
- if self._producer and r.transport is self._producer.transport:
- self._producer.resumeProducing()
- self._connection_waiter = None
- return
-
- # If the producer is still producing, stop it.
- if self._producer:
- self._producer.stopProducing()
-
- # Make a new producer and start it.
- self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
- r.transport.registerProducer(self._producer, True)
- self._producer.resumeProducing()
- self._connection_waiter = None
-
- def _handle_pressure(self) -> None:
- """
- Handle backpressure by shedding events.
-
- The buffer will, in this order, until the buffer is below the maximum:
- - Shed DEBUG events
- - Shed INFO events
- - Shed the middle 50% of the events.
- """
- if len(self._buffer) <= self.maximum_buffer:
- return
-
- # Strip out DEBUGs
- self._buffer = deque(
- filter(lambda event: event["level"] != "DEBUG", self._buffer)
- )
-
- if len(self._buffer) <= self.maximum_buffer:
- return
-
- # Strip out INFOs
- self._buffer = deque(
- filter(lambda event: event["level"] != "INFO", self._buffer)
- )
-
- if len(self._buffer) <= self.maximum_buffer:
- return
-
- # Cut the middle entries out
- buffer_split = floor(self.maximum_buffer / 2)
-
- old_buffer = self._buffer
- self._buffer = deque()
-
- for i in range(buffer_split):
- self._buffer.append(old_buffer.popleft())
-
- end_buffer = []
- for i in range(buffer_split):
- end_buffer.append(old_buffer.pop())
-
- self._buffer.extend(reversed(end_buffer))
-
- def __call__(self, event: dict) -> None:
- flattened = flatten_event(event, self.metadata, include_time=True)
- self._buffer.append(flattened)
-
- # Handle backpressure, if it exists.
- try:
- self._handle_pressure()
- except Exception:
- # If handling backpressure fails,clear the buffer and log the
- # exception.
- self._buffer.clear()
- self._logger.failure("Failed clearing backpressure")
+ def formatEvent(_event: dict) -> str:
+ flattened = flatten_event(_event, metadata, include_time=True)
+ return _encoder.encode(flattened) + "\n"
- # Try and write immediately.
- self._connect()
+ return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
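The net effect of the split is that `TCPLogObserver` owns only connection management and buffering, taking any `Callable[[dict], str]` for serialisation, while `TerseJSONToTCPLogObserver` shrinks to the closure above. Any other line-oriented format could be plugged in the same way; a hedged sketch (the key=value format is hypothetical, only the `TCPLogObserver` signature comes from `_remote.py` above):

```python
# A hypothetical alternative formatter, wired up the same way as the JSON
# one; the key=value rendering is illustrative, not something Synapse ships.
from synapse.logging._remote import TCPLogObserver


def LogfmtToTCPLogObserver(
    hs, host: str, port: int, maximum_buffer: int
) -> TCPLogObserver:
    def format_event(event: dict) -> str:
        # LogProducer writes whatever this returns verbatim, so the
        # formatter is responsible for the trailing newline.
        pairs = " ".join("%s=%r" % (k, v) for k, v in sorted(event.items()))
        return pairs + "\n"

    return TCPLogObserver(hs, host, port, format_event, maximum_buffer)
```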
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 1681caa1f0..a6d1eb908a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
- self, user_localpart: str, new_displayname: str
+ self, user_localpart: str, new_displayname: Optional[str]
) -> None:
await self.db_pool.simple_update_one(
table="profiles",
@@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
- ) -> Dict[str, str]:
+ ) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 2acb8b7603..97f8cad0dd 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -260,6 +260,31 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
+ def test_send_large_txns(self):
+ srv_1_defer = defer.Deferred()
+ srv_2_defer = defer.Deferred()
+ send_return_list = [srv_1_defer, srv_2_defer]
+
+ def do_send(x, y, z):
+ return make_deferred_yieldable(send_return_list.pop(0))
+
+ self.txn_ctrl.send = Mock(side_effect=do_send)
+
+ service = Mock(id=4, name="service")
+ event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
+ for event in event_list:
+ self.queuer.enqueue_event(service, event)
+
+ # Expect the first event to be sent immediately.
+ self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
+ srv_1_defer.callback(service)
+ # Then send the next 100 events
+ self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
+ srv_2_defer.callback(service)
+ # Then the final 99 events
+ self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
+ self.assertEquals(3, self.txn_ctrl.send.call_count)
+
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
@@ -296,3 +321,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.assertEquals(2, self.txn_ctrl.send.call_count)
+
+ def test_send_large_txns_ephemeral(self):
+ d = defer.Deferred()
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y, z: make_deferred_yieldable(d)
+ )
+ # Expect the event to be sent immediately.
+ service = Mock(id=4, name="service")
+ first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
+ second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
+ event_list = first_chunk + second_chunk
+ self.queuer.enqueue_ephemeral(service, event_list)
+ self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
+ d.callback(service)
+ self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
+ self.assertEquals(2, self.txn_ctrl.send.call_count)
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 4cf81f7128..fd128b88e0 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -78,7 +78,7 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"server_name",
"name",
]
- self.assertEqual(set(log.keys()), set(expected_log_keys))
+ self.assertCountEqual(log.keys(), expected_log_keys)
# It contains the data we expect.
self.assertEqual(log["name"], "wally")
|