diff --git a/changelog.d/9563.misc b/changelog.d/9563.misc
new file mode 100644
index 0000000000..7a3493e4a1
--- /dev/null
+++ b/changelog.d/9563.misc
@@ -0,0 +1 @@
+Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
diff --git a/changelog.d/9587.bugfix b/changelog.d/9587.bugfix
new file mode 100644
index 0000000000..d8f04c4f21
--- /dev/null
+++ b/changelog.d/9587.bugfix
@@ -0,0 +1 @@
+Re-Activating account with admin API when local passwords are disabled.
\ No newline at end of file
diff --git a/changelog.d/9590.misc b/changelog.d/9590.misc
new file mode 100644
index 0000000000..186396c45b
--- /dev/null
+++ b/changelog.d/9590.misc
@@ -0,0 +1 @@
+Add logging for redis connection setup.
diff --git a/changelog.d/9591.misc b/changelog.d/9591.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9591.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/changelog.d/9596.misc b/changelog.d/9596.misc
new file mode 100644
index 0000000000..fc19a95f75
--- /dev/null
+++ b/changelog.d/9596.misc
@@ -0,0 +1 @@
+Improve logging when processing incoming transactions.
diff --git a/changelog.d/9597.bugfix b/changelog.d/9597.bugfix
new file mode 100644
index 0000000000..349dc9d664
--- /dev/null
+++ b/changelog.d/9597.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages.
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 618548a305..34787e0b1e 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -17,6 +17,8 @@
"""
from typing import Any, List, Optional, Type, Union
+from twisted.internet import protocol
+
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
@@ -52,7 +54,7 @@ def lazyConnection(
class ConnectionHandler: ...
-class RedisFactory:
+class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3370bc74cf..8d0f6b7b31 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -164,7 +164,7 @@ class Auth:
async def get_user_by_req(
self,
- request: Request,
+ request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5e8b86bc96..8206f65b5e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -113,10 +113,11 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
- self._federation_ratelimiter = hs.get_federation_ratelimiter()
-
self._server_linearizer = Linearizer("fed_server")
- self._transaction_linearizer = Linearizer("fed_txn_handler")
+
+ # origins that we are currently processing a transaction from.
+ # a dict from origin to txn id.
+ self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
@@ -170,6 +171,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
+ # Reject malformed transactions early: reject if too many PDUs/EDUs
+ if len(transaction.pdus) > 50 or ( # type: ignore
+ hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
+ ):
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
+ return 400, {}
+
+ # we only process one transaction from each origin at a time. We need to do
+ # this check here, rather than in _on_incoming_transaction_inner so that we
+ # don't cache the rejection in _transaction_resp_cache (so that if the txn
+ # arrives again later, we can process it).
+ current_transaction = self._active_transactions.get(origin)
+ if current_transaction and current_transaction != transaction_id:
+ logger.warning(
+ "Received another txn %s from %s while still processing %s",
+ transaction_id,
+ origin,
+ current_transaction,
+ )
+ return 429, {
+ "errcode": Codes.UNKNOWN,
+ "error": "Too many concurrent transactions",
+ }
+
+ # CRITICAL SECTION: we must now not await until we populate _active_transactions
+ # in _on_incoming_transaction_inner.
+
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
@@ -183,26 +211,18 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
- # Use a linearizer to ensure that transactions from a remote are
- # processed in order.
- with await self._transaction_linearizer.queue(origin):
- # We rate limit here *after* we've queued up the incoming requests,
- # so that we don't fill up the ratelimiter with blocked requests.
- #
- # This is important as the ratelimiter allows N concurrent requests
- # at a time, and only starts ratelimiting if there are more requests
- # than that being processed at a time. If we queued up requests in
- # the linearizer/response cache *after* the ratelimiting then those
- # queued up requests would count as part of the allowed limit of N
- # concurrent requests.
- with self._federation_ratelimiter.ratelimit(origin) as d:
- await d
-
- result = await self._handle_incoming_transaction(
- origin, transaction, request_time
- )
+ # CRITICAL SECTION: the first thing we must do (before awaiting) is
+ # add an entry to _active_transactions.
+ assert origin not in self._active_transactions
+ self._active_transactions[origin] = transaction.transaction_id # type: ignore
- return result
+ try:
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
+ )
+ return result
+ finally:
+ del self._active_transactions[origin]
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
@@ -228,19 +248,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
- # Reject if PDU count > 50 or EDU count > 100
- if len(transaction.pdus) > 50 or ( # type: ignore
- hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
- ):
-
- logger.info("Transaction PDU or EDU count too large. Returning 400")
-
- response = {}
- await self.transaction_actions.set_response(
- origin, transaction, 400, response
- )
- return 400, response
-
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@@ -336,34 +343,41 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
- logger.debug("Processing PDUs for %s", room_id)
- try:
- await self.check_server_matches_acl(origin_host, room_id)
- except AuthError as e:
- logger.warning("Ignoring PDUs for room %s from banned server", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- pdu_results[event_id] = e.error_dict()
- return
+ with nested_logging_context(room_id):
+ logger.debug("Processing PDUs for %s", room_id)
- for pdu in pdus_by_room[room_id]:
- event_id = pdu.event_id
- with pdu_process_time.time():
- with nested_logging_context(event_id):
- try:
- await self._handle_received_pdu(origin, pdu)
- pdu_results[event_id] = {}
- except FederationError as e:
- logger.warning("Error handling PDU %s: %s", event_id, e)
- pdu_results[event_id] = {"error": str(e)}
- except Exception as e:
- f = failure.Failure()
- pdu_results[event_id] = {"error": str(e)}
- logger.error(
- "Failed to handle PDU %s",
- event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
- )
+ try:
+ await self.check_server_matches_acl(origin_host, room_id)
+ except AuthError as e:
+ logger.warning(
+ "Ignoring PDUs for room %s from banned server", room_id
+ )
+ for pdu in pdus_by_room[room_id]:
+ event_id = pdu.event_id
+ pdu_results[event_id] = e.error_dict()
+ return
+
+ for pdu in pdus_by_room[room_id]:
+ pdu_results[pdu.event_id] = await process_pdu(pdu)
+
+ async def process_pdu(pdu: EventBase) -> JsonDict:
+ event_id = pdu.event_id
+ with pdu_process_time.time():
+ with nested_logging_context(event_id):
+ try:
+ await self._handle_received_pdu(origin, pdu)
+ return {}
+ except FederationError as e:
+ logger.warning("Error handling PDU %s: %s", event_id, e)
+ return {"error": str(e)}
+ except Exception as e:
+ f = failure.Failure()
+ logger.error(
+ "Failed to handle PDU %s",
+ event_id,
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
+ )
+ return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -942,7 +956,9 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
- self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
+ self.query_handlers = (
+ {}
+ ) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
@@ -976,7 +992,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
- self, query_type: str, handler: Callable[[dict], defer.Deferred]
+ self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -1049,7 +1065,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
- async def on_query(self, query_type: str, args: dict):
+ async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0f10cc3dc1..19a55f0971 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -202,7 +202,7 @@ class FederationHandler(BaseHandler):
or pdu.internal_metadata.is_outlier()
)
if already_seen:
- logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
+ logger.debug("Already seen pdu")
return
# do some initial sanity-checking of the event. In particular, make
@@ -211,18 +211,14 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warning(
- "[%s %s] Received event failed sanity checks", room_id, event_id
- )
+ logger.warning("Received event failed sanity checks")
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
- "[%s %s] Queuing PDU from %s for now: join in progress",
- room_id,
- event_id,
+ "Queuing PDU from %s for now: join in progress",
origin,
)
self.room_queues[room_id].append((pdu, origin))
@@ -237,9 +233,7 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
- "[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id,
- event_id,
+ "Ignoring PDU from %s as we're not in the room",
origin,
)
return None
@@ -251,7 +245,7 @@ class FederationHandler(BaseHandler):
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
+ logger.debug("min_depth: %d", min_depth)
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
@@ -268,17 +262,13 @@ class FederationHandler(BaseHandler):
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
- "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id,
- event_id,
+ "Acquiring room lock to fetch %d missing prev_events: %s",
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
- "[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id,
- event_id,
+ "Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
)
@@ -298,9 +288,7 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
- "[%s %s] Found all missing prev_events",
- room_id,
- event_id,
+ "Found all missing prev_events",
)
elif missing_prevs:
logger.info(
@@ -338,9 +326,7 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warning(
- "[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id,
- event_id,
+ "Rejecting: failed to fetch %d prev events: %s",
len(prevs - seen),
shortstr(prevs - seen),
)
@@ -416,10 +402,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
- "[%s %s] Error attempting to resolve state at missing "
- "prev_events",
- room_id,
- event_id,
+ "Error attempting to resolve state at missing " "prev_events",
exc_info=True,
)
raise FederationError(
@@ -456,9 +439,7 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
- "[%s %s]: Requesting missing events between %s and %s",
- room_id,
- event_id,
+ "Requesting missing events between %s and %s",
shortstr(latest),
event_id,
)
@@ -525,15 +506,11 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
- logger.warning(
- "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
- )
+ logger.warning("Failed to get prev_events: %s", e)
return
logger.info(
- "[%s %s]: Got %d prev_events: %s",
- room_id,
- event_id,
+ "Got %d prev_events: %s",
len(missing_events),
shortstr(missing_events),
)
@@ -544,9 +521,7 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
- "[%s %s] Handling received prev_event %s",
- room_id,
- event_id,
+ "Handling received prev_event %s",
ev.event_id,
)
with nested_logging_context(ev.event_id):
@@ -555,9 +530,7 @@ class FederationHandler(BaseHandler):
except FederationError as e:
if e.code == 403:
logger.warning(
- "[%s %s] Received prev_event %s failed history check.",
- room_id,
- event_id,
+ "Received prev_event %s failed history check.",
ev.event_id,
)
else:
@@ -709,10 +682,7 @@ class FederationHandler(BaseHandler):
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
- room_id = event.room_id
- event_id = event.event_id
-
- logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+ logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 825fadb76f..f5d1821127 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
from typing_extensions import TypedDict
from twisted.web.client import readBody
+from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.config.oidc_config import (
@@ -538,7 +539,7 @@ class OidcProvider:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
- headers = {
+ raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@@ -552,10 +553,10 @@ class OidcProvider:
body = urlencode(args, True)
# Fill the body/headers with credentials
- uri, headers, body = self._client_auth.prepare(
- method="POST", uri=token_endpoint, headers=headers, body=body
+ uri, raw_headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
- headers = {k: [v] for (k, v) in headers.items()}
+ headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
diff --git a/synapse/http/client.py b/synapse/http/client.py
index af34d583ad..d4ab3a2732 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -39,6 +39,7 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
@@ -56,7 +57,13 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+ UNKNOWN_LENGTH,
+ IAgent,
+ IBodyProducer,
+ IPolicyForHTTPS,
+ IResponse,
+)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -151,16 +158,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
-
- r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
- r.resolutionBegan(None)
-
has_bad_ip = False
- for i in addresses:
- ip_address = IPAddress(i.host)
+ for address in addresses:
+ # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+ # should go through this path.
+ if not isinstance(address, (IPv4Address, IPv6Address)):
+ continue
+
+ ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -175,15 +183,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
- for i in addresses:
- r.addressResolved(i)
- r.resolutionComplete()
+ for address in addresses:
+ recv.addressResolved(address)
+ recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
- pass
+ recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@@ -197,7 +205,7 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
- return r
+ return recv
@implementer(ISynapseReactor)
@@ -346,7 +354,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
- )
+ ) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -868,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
+@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 174ca7be5a..643492ceaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
- transport = attr.ib(type=ITransport)
+ # This is essentially ITCPTransport, but that is missing certain fields
+ # (connected and registerProducer) which are part of the implementation.
+ transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
- self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
+ # Force recognising transport as a Connection and not the more
+ # generic ITransport.
+ transport = result.transport # type: Connection # type: ignore
+
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
- if self._producer and result.transport is self._producer.transport:
+ if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
- transport=result.transport,
+ transport=transport,
format=self.format,
)
- result.transport.registerProducer(self._producer, True)
+ transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
- self._connection_waiter.addCallbacks(writer, fail)
+ deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
+ deferred.addCallbacks(writer, fail)
+ self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 5fec2aaf5d..3dc06a79e8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
-from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
- self.timed_call = None # type: Optional[DelayedCall]
+ self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..ee909f3fc5 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
- Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+ Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
- self._connections = [] # type: List[AbstractConnection]
+ self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+ self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
- self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
- conn: AbstractConnection,
+ conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate
- def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
- def send_positions_to_connection(self, conn: AbstractConnection):
+ def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
)
def on_USER_SYNC(
- self, conn: AbstractConnection, cmd: UserSyncCommand
+ self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
- self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
- def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+ def on_FEDERATION_ACK(
+ self, conn: IReplicationConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP(
- self, conn: AbstractConnection, cmd: UserIpCommand
+ self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
- def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
- self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
- self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
- def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+ def on_REMOTE_SERVER_UP(
+ self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
- def new_connection(self, connection: AbstractConnection):
+ def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
- def lost_connection(self, connection: AbstractConnection):
+ def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
- self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..8e4734b59c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
@@ -54,6 +53,7 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from zope.interface import Interface, implementer
from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
@@ -121,6 +121,14 @@ class ConnectionStates:
CLOSED = "closed"
+class IReplicationConnection(Interface):
+ """An interface for replication connections."""
+
+ def send_command(cmd: Command):
+ """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
-class AbstractConnection(abc.ABC):
- """An interface for replication connections."""
-
- @abc.abstractmethod
- def send_command(self, cmd: Command):
- """Send the command down the connection"""
- pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 7560706b4b..7cccde097d 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
+from zope.interface import implementer
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import IAddress, IConnector
+from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
@@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
- AbstractConnection,
+ IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
- (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
except Exception:
logger.warning("Failed to send ping to a redis connection")
+ # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
+ # it's rubbish. We add our own here.
+
+ def startedConnecting(self, connector: IConnector):
+ logger.info(
+ "Connecting to redis server %s", format_address(connector.getDestination())
+ )
+ super().startedConnecting(connector)
+
+ def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s failed: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector: IConnector, reason: Failure):
+ logger.info(
+ "Connection to redis server %s lost: %s",
+ format_address(connector.getDestination()),
+ reason.value,
+ )
+ super().clientConnectionLost(connector, reason)
+
+
+def format_address(address: IAddress) -> str:
+ if isinstance(address, (IPv4Address, IPv6Address)):
+ return "%s:%i" % (address.host, address.port)
+ return str(address)
+
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index e09234c644..7681e55b58 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,10 +15,9 @@
import re
-import twisted.web.server
-
-import synapse.api.auth
+from synapse.api.auth import Auth
from synapse.api.errors import AuthError
+from synapse.http.site import SynapseRequest
from synapse.types import UserID
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns
-async def assert_requester_is_admin(
- auth: synapse.api.auth.Auth, request: twisted.web.server.Request
-) -> None:
+async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
"""Verify that the requester is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
request: incoming request
Raises:
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user)
-async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
Args:
- auth: api.auth.Auth singleton
+ auth: Auth singleton
user_id: user to check
Raises:
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 511c859f64..7fcc48a9d7 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,10 +17,9 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from twisted.web.server import Request
-
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
@@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
self.auth = hs.get_auth()
async def on_POST(
- self, request: Request, server_name: str, media_id: str
+ self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_DELETE(
- self, request: Request, server_name: str, media_id: str
+ self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, server_name: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 267a993430..2c89b62e25 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
- if "password" not in body:
+ if (
+ "password" not in body
+ and self.hs.config.password_localdb_enabled
+ ):
raise SynapseError(
400, "Must provide a password to re-activate an account."
)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 7aea4cebf5..5901432fad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.types import GroupID, JsonDict
from ._base import client_patterns
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_POST(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ category_id: Optional[str],
+ room_id: str,
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, category_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_GET(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, category_id: str
+ self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_GET(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, role_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+ self,
+ request: SynapseRequest,
+ group_id: str,
+ role_id: Optional[str],
+ user_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, role_id: str, user_id: str
+ self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
- self, request: Request, group_id: str, room_id: str
+ self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
- self, request: Request, group_id: str, room_id: str, config_key: str
+ self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id, user_id
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+ async def on_PUT(
+ self, request: SynapseRequest, group_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9039662f7e..1eff98ef14 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a074e807dc..b8895aeaa9 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -39,6 +39,7 @@ from synapse.http.server import (
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
@@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_GET(self, request: Request) -> None:
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 5e104fac40..ae5aef2f7f 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_POST(self, request: Request) -> None:
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
diff --git a/synapse/server.py b/synapse/server.py
index 369cc88026..48ac87a124 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- return (
- InsecureInterceptableContextFactory()
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
- else RegularPolicyForHTTPS()
- )
+ if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ return InsecureInterceptableContextFactory()
+ return RegularPolicyForHTTPS()
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 21ecb81c99..0ce181a51e 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock
+from netaddr import IPSet
+
+from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
-from twisted.web.client import ResponseDone
+from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH
-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+from synapse.api.errors import SynapseError
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ BlacklistingReactorWrapper,
+ BodyExceededMaxSize,
+ read_body_with_max_size,
+)
+from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
+
+
+class BlacklistingAgentTest(TestCase):
+ def setUp(self):
+ self.reactor, self.clock = get_clock()
+
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+ # Configure the reactor's DNS resolver.
+ for (domain, ip) in (
+ (self.safe_domain, self.safe_ip),
+ (self.unsafe_domain, self.unsafe_ip),
+ (self.allowed_domain, self.allowed_ip),
+ ):
+ self.reactor.lookups[domain.decode()] = ip.decode()
+ self.reactor.lookups[ip.decode()] = ip.decode()
+
+ self.ip_whitelist = IPSet([self.allowed_ip.decode()])
+ self.ip_blacklist = IPSet(["5.0.0.0/8"])
+
+ def test_reactor(self):
+ """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
+ agent = Agent(
+ BlacklistingReactorWrapper(
+ self.reactor,
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ ),
+ )
+
+ # The unsafe domains and IPs should be rejected.
+ for domain in (self.unsafe_domain, self.unsafe_ip):
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
+ )
+
+ # The safe domains IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
+
+ def test_agent(self):
+ """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
+ agent = BlacklistingAgentWrapper(
+ Agent(self.reactor),
+ ip_whitelist=self.ip_whitelist,
+ ip_blacklist=self.ip_blacklist,
+ )
+
+ # The unsafe IPs should be rejected.
+ self.failureResultOf(
+ agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
+ )
+
+ # The safe and unsafe domains and safe IPs should be accepted.
+ for domain in (
+ self.safe_domain,
+ self.unsafe_domain,
+ self.allowed_domain,
+ self.safe_ip,
+ self.allowed_ip,
+ ):
+ d = agent.request(b"GET", b"http://" + domain)
+
+ # Grab the latest TCP connection.
+ (
+ host,
+ port,
+ client_factory,
+ _timeout,
+ _bindAddress,
+ ) = self.reactor.tcpClients[-1]
+
+ # Make the connection and pump data through it.
+ client = client_factory.buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
+ )
+
+ response = self.successResultOf(d)
+ self.assertEqual(response.code, 200)
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index f235f1bd83..0d9e3bb11d 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
- # wire up the ReplicationCommandHandler to a mock connection
- mock_connection = mock.Mock(spec=AbstractConnection)
+ # wire up the ReplicationCommandHandler to a mock connection, which needs
+ # to implement IReplicationConnection. (Note that Mock doesn't understand
+ # interfaces, but casing an interface to a list gives the attributes.)
+ mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row
|