diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index a01bac2997..4a9b0129c3 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -17,8 +17,6 @@ import sys
from synapse import python_dependencies # noqa: E402
-sys.dont_write_bytecode = True
-
logger = logging.getLogger(__name__)
try:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 9ba9f591d9..3978e41518 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -36,7 +36,7 @@ import attr
import bcrypt
import pymacaroons
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -481,7 +481,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
# Convert the URI and method to strings.
- uri = request.uri.decode("utf-8")
+ uri = request.uri.decode("utf-8") # type: ignore
method = request.method.decode("utf-8")
# If there's no session ID, create a new session.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b6a9ce4f38..54631b4ee2 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -274,22 +274,25 @@ class PresenceHandler(BasePresenceHandler):
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
- # Start a LoopingCall in 30s that fires every 5s.
- # The initial delay is to allow disconnected clients a chance to
- # reconnect before we treat them as offline.
- def run_timeout_handler():
- return run_as_background_process(
- "handle_presence_timeouts", self._handle_timeouts
- )
-
- self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
+ if self._presence_enabled:
+ # Start a LoopingCall in 30s that fires every 5s.
+ # The initial delay is to allow disconnected clients a chance to
+ # reconnect before we treat them as offline.
+ def run_timeout_handler():
+ return run_as_background_process(
+ "handle_presence_timeouts", self._handle_timeouts
+ )
- def run_persister():
- return run_as_background_process(
- "persist_presence_changes", self._persist_unpersisted_changes
+ self.clock.call_later(
+ 30, self.clock.looping_call, run_timeout_handler, 5000
)
- self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
+ def run_persister():
+ return run_as_background_process(
+ "persist_presence_changes", self._persist_unpersisted_changes
+ )
+
+ self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
LaterGauge(
"synapse_handlers_presence_wheel_timer_size",
@@ -299,7 +302,7 @@ class PresenceHandler(BasePresenceHandler):
)
# Used to handle sending of presence to newly joined users/servers
- if hs.config.use_presence:
+ if self._presence_enabled:
self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so lets just always
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 514b1f69d8..80e28bdcbe 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -31,8 +31,8 @@ from urllib.parse import urlencode
import attr
from typing_extensions import NoReturn, Protocol
-from twisted.web.http import Request
from twisted.web.iweb import IRequest
+from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 16b68d630a..6c8e361402 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -278,9 +278,8 @@ class SyncHandler:
user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester)
- res = await self.response_cache.wrap_conditional(
+ res = await self.response_cache.wrap(
sync_config.request_key,
- lambda result: since_token != result.next_batch,
self._wait_for_sync_for_user,
sync_config,
since_token,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e54d9bd213..72901e3f95 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -289,8 +289,7 @@ class SimpleHttpClient:
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
- http_proxy: Optional[bytes] = None,
- https_proxy: Optional[bytes] = None,
+ use_proxy: bool = False,
):
"""
Args:
@@ -300,8 +299,8 @@ class SimpleHttpClient:
we may not request.
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy: proxy server to use for http connections. host[:port]
- https_proxy: proxy server to use for https connections. host[:port]
+ use_proxy: Whether proxy settings should be discovered and used
+ from conventional environment variables.
"""
self.hs = hs
@@ -345,8 +344,7 @@ class SimpleHttpClient:
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
- http_proxy=http_proxy,
- https_proxy=https_proxy,
+ use_proxy=use_proxy,
)
if self._ip_blacklist:
@@ -750,7 +748,32 @@ class BodyExceededMaxSize(Exception):
"""The maximum allowed size of the HTTP body was exceeded."""
+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+ """A protocol which immediately errors upon receiving data."""
+
+ def __init__(self, deferred: defer.Deferred):
+ self.deferred = deferred
+
+ def _maybe_fail(self):
+ """
+        Report a max size exceeded error and disconnect the first time this is called.
+ """
+ if not self.deferred.called:
+ self.deferred.errback(BodyExceededMaxSize())
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ self.transport.abortConnection()
+
+ def dataReceived(self, data: bytes) -> None:
+ self._maybe_fail()
+
+ def connectionLost(self, reason: Failure) -> None:
+ self._maybe_fail()
+
+
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
+ """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -807,13 +830,15 @@ def read_body_with_max_size(
Returns:
A Deferred which resolves to the length of the read body.
"""
+ d = defer.Deferred()
+
# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_size is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_size:
- return defer.fail(BodyExceededMaxSize())
+ response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+ return d
- d = defer.Deferred()
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b730d2c634..3d553ae236 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
import re
+from urllib.request import getproxies_environment, proxy_bypass_environment
from zope.interface import implementer
@@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
non-persistent pool instance will be created.
+
+ use_proxy (bool): Whether proxy settings should be discovered and used
+ from conventional environment variables.
"""
def __init__(
@@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
connectTimeout=None,
bindAddress=None,
pool=None,
- http_proxy=None,
- https_proxy=None,
+ use_proxy=False,
):
_AgentBase.__init__(self, reactor, pool)
@@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
if bindAddress is not None:
self._endpoint_kwargs["bindAddress"] = bindAddress
+ http_proxy = None
+ https_proxy = None
+ no_proxy = None
+ if use_proxy:
+ proxies = getproxies_environment()
+ http_proxy = proxies["http"].encode() if "http" in proxies else None
+ https_proxy = proxies["https"].encode() if "https" in proxies else None
+ no_proxy = proxies["no"] if "no" in proxies else None
+
self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
@@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
+ self.no_proxy = no_proxy
+
self._policy_for_https = contextFactory
self._reactor = reactor
@@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
request_path = parsed_uri.originForm
- if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ should_skip_proxy = False
+ if self.no_proxy is not None:
+ should_skip_proxy = proxy_bypass_environment(
+ parsed_uri.host.decode(),
+ proxies={"no": self.no_proxy},
+ )
+
+ if (
+ parsed_uri.scheme == b"http"
+ and self.http_proxy_endpoint
+ and not should_skip_proxy
+ ):
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
endpoint = self.http_proxy_endpoint
request_path = uri
- elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ elif (
+ parsed_uri.scheme == b"https"
+ and self.https_proxy_endpoint
+ and not should_skip_proxy
+ ):
endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor,
self.https_proxy_endpoint,
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 439881be67..c10992ff51 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,9 +15,10 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_left_room
@@ -78,7 +79,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
- self, request: Request, room_id: str, user_id: str
+ self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@@ -86,7 +87,6 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
-
request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -147,7 +147,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
- self, request: Request, invite_event_id: str
+ self, request: SynapseRequest, invite_event_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@@ -155,7 +155,6 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
-
request.requester = requester
# hopefully we're now on the master, so this won't recurse!
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index ffd3aa38f7..5996de11c3 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -20,8 +21,12 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id, device_id):
+ async def on_GET(
+        self, request: SynapseRequest, user_id: str, device_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
)
return 200, device
- async def on_DELETE(self, request, user_id, device_id):
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, device_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(target_user.to_string(), device_id)
return 200, {}
- async def on_PUT(self, request, user_id, device_id):
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, device_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_GET(self, request, user_id):
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
- async def on_POST(self, request, user_id):
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index fd482f0e32..381c3fe685 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -14,10 +14,16 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, report_id):
+ async def on_GET(
+ self, request: SynapseRequest, report_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
message = (
"The report_id parameter must be a string representing a positive integer."
)
try:
- report_id = int(report_id)
+ resolved_report_id = int(report_id)
except ValueError:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
- if report_id < 0:
+ if resolved_report_id < 0:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
- ret = await self.store.get_event_report(report_id)
+ ret = await self.store.get_event_report(resolved_report_id)
if not ret:
raise NotFoundError("Event report not found")
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index b996862c05..511c859f64 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,7 +17,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 1a3a36f6cf..f2c42a0f30 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -44,6 +44,48 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ResolveRoomIdMixin:
+ def __init__(self, hs: "HomeServer"):
+ self.room_member_handler = hs.get_room_member_handler()
+
+ async def resolve_room_id(
+ self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
+ ) -> Tuple[str, Optional[List[str]]]:
+ """
+ Resolve a room identifier to a room ID, if necessary.
+
+        This also performs checks to ensure the room ID is of the proper form.
+
+ Args:
+ room_identifier: The room ID or alias.
+ remote_room_hosts: The potential remote room hosts to use.
+
+ Returns:
+            The resolved room ID and the remote room hosts to use, if any.
+
+ Raises:
+ SynapseError if the room ID is of the wrong form.
+ """
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ (
+ room_id,
+ remote_room_hosts,
+ ) = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id, remote_room_hosts
+
+
class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
@@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet):
return 200, ret
-class JoinRoomAliasServlet(RestServlet):
+class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
@@ -362,22 +404,16 @@ class JoinRoomAliasServlet(RestServlet):
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
- if RoomID.is_valid(room_identifier):
- room_id = room_identifier
- try:
- remote_room_hosts = [
- x.decode("ascii") for x in request.args[b"server_name"]
- ] # type: Optional[List[str]]
- except Exception:
- remote_room_hosts = None
- elif RoomAlias.is_valid(room_identifier):
- handler = self.room_member_handler
- room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- else:
- raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
- )
+ # Get the room ID from the identifier.
+ try:
+ remote_room_hosts = [
+ x.decode("ascii") for x in request.args[b"server_name"]
+ ] # type: Optional[List[str]]
+ except Exception:
+ remote_room_hosts = None
+ room_id, remote_room_hosts = await self.resolve_room_id(
+ room_identifier, remote_room_hosts
+ )
fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
@@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id}
-class MakeRoomAdminRestServlet(RestServlet):
+class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get power in a room if a local user has power in
a room. Will also invite the user if they're not in the room and it's a
private room. Can specify another user (rather than the admin user) to be
@@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.room_member_handler = hs.get_room_member_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
- async def on_POST(self, request, room_identifier):
+ async def on_POST(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request, allow_empty_body=True)
- # Resolve to a room ID, if necessary.
- if RoomID.is_valid(room_identifier):
- room_id = room_identifier
- elif RoomAlias.is_valid(room_identifier):
- room_alias = RoomAlias.from_string(room_identifier)
- room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
- else:
- raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
- )
+ room_id, _ = await self.resolve_room_id(room_identifier)
# Which user to grant room admin rights to.
user_to_add = content.get("user_id", requester.user.to_string())
@@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet):
return 200, {}
-class ForwardExtremitiesRestServlet(RestServlet):
+class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get or clear forward extremities.
Clearing does not require restarting the server.
@@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.room_member_handler = hs.get_room_member_handler()
self.store = hs.get_datastore()
- async def resolve_room_id(self, room_identifier: str) -> str:
- """Resolve to a room ID, if necessary."""
- if RoomID.is_valid(room_identifier):
- resolved_room_id = room_identifier
- elif RoomAlias.is_valid(room_identifier):
- room_alias = RoomAlias.from_string(room_identifier)
- room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
- resolved_room_id = room_id.to_string()
- else:
- raise SynapseError(
- 400, "%s was not legal room ID or room alias" % (room_identifier,)
- )
- if not resolved_room_id:
- raise SynapseError(
- 400, "Unknown room ID or room alias %s" % room_identifier
- )
- return resolved_room_id
-
- async def on_DELETE(self, request, room_identifier):
+ async def on_DELETE(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
- room_id = await self.resolve_room_id(room_identifier)
+ room_id, _ = await self.resolve_room_id(room_identifier)
deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
return 200, {"deleted": deleted_count}
- async def on_GET(self, request, room_identifier):
+ async def on_GET(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
- room_id = await self.resolve_room_id(room_identifier)
+ room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities}
@@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await assert_user_is_admin(self.auth, requester.user)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d3434225cb..7aea4cebf5 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -18,7 +18,7 @@ import logging
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.constants import (
MAX_GROUP_CATEGORYID_LENGTH,
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 90bbeca679..6366947071 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
+ # The type on postpath seems incorrect in Twisted 21.2.0.
+ postpath = request.postpath # type: List[bytes] # type: ignore
+ assert postpath
+
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
- server_name, media_id = request.postpath[:2]
-
- if isinstance(server_name, bytes):
- server_name = server_name.decode("utf-8")
- media_id = media_id.decode("utf8")
+ server_name_bytes, media_id_bytes = postpath[:2]
+ server_name = server_name_bytes.decode("utf-8")
+ media_id = media_id_bytes.decode("utf8")
file_name = None
- if len(request.postpath) > 2:
+ if len(postpath) > 2:
try:
- file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
+ file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 4e4c6971f7..9039662f7e 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 48f4433155..8a43581f1f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 3375455c43..0641924f18 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
-from twisted.web.http import Request
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 6104ef4e46..a074e807dc 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -29,7 +29,7 @@ from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -149,8 +149,7 @@ class PreviewUrlResource(DirectServeJsonResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
- http_proxy=os.getenvb(b"http_proxy"),
- https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ use_proxy=True,
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 3ab90e9f9b..fbcd50f1e2 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -18,7 +18,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 1136277794..5e104fac40 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,9 +15,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import IO, TYPE_CHECKING
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
@@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource):
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
- media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
+ content_type_headers = headers.getRawHeaders(b"Content-Type")
+ assert content_type_headers # for mypy
+ media_type = content_type_headers[0].decode("ascii")
else:
raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
@@ -88,8 +90,9 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-dispostion
try:
+ content = request.content # type: IO # type: ignore
content_uri = await self.media_repo.create_content(
- media_type, upload_name, request.content, content_length, requester.user
+ media_type, upload_name, content, content_length, requester.user
)
except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index b2e0f93810..78ee0b5e88 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 9e4fbc0cbd..d26ce46efc 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index 96077cfcd1..51acaa9a92 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, List
-from twisted.web.http import Request
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
index dfefeb7796..f2acce2437 100644
--- a/synapse/rest/synapse/client/sso_register.py
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING
-from twisted.web.http import Request
+from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
diff --git a/synapse/server.py b/synapse/server.py
index 4b9ec7f0ae..1d4370e0ba 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -24,7 +24,6 @@
import abc
import functools
import logging
-import os
from typing import (
TYPE_CHECKING,
Any,
@@ -370,11 +369,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
An HTTP client that uses configured HTTP(S) proxies.
"""
- return SimpleHttpClient(
- self,
- http_proxy=os.getenvb(b"http_proxy"),
- https_proxy=os.getenvb(b"HTTPS_PROXY"),
- )
+ return SimpleHttpClient(self, use_proxy=True)
@cache_in_self
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
@@ -386,8 +381,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self,
ip_whitelist=self.config.ip_range_whitelist,
ip_blacklist=self.config.ip_range_blacklist,
- http_proxy=os.getenvb(b"http_proxy"),
- https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ use_proxy=True,
)
@cache_in_self
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 74219cb05e..6b608ebc9b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -39,6 +39,11 @@ class PusherWorkerStore(SQLBaseStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
+ self.db_pool.updates.register_background_update_handler(
+ "remove_deactivated_pushers",
+ self._remove_deactivated_pushers,
+ )
+
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -284,6 +289,54 @@ class PusherWorkerStore(SQLBaseStore):
lock=False,
)
+ async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
+ """A background update that deletes all pushers for deactivated users.
+
+        Note that we don't proactively tell the pusherpool that we've deleted
+        these (just because it's a bit of a faff to do from here), but they will
+        get cleaned up at the next restart.
+ """
+
+ last_user = progress.get("last_user", "")
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT name FROM users
+ WHERE deactivated = ? and name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (1, last_user, batch_size))
+ users = [row[0] for row in txn]
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="pushers",
+ column="user_name",
+ iterable=users,
+ keyvalues={},
+ )
+
+ if users:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_deactivated_pushers", {"last_user": users[-1]}
+ )
+
+ return len(users)
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deactivated_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "remove_deactivated_pushers"
+ )
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql b/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql
index 20ba4abca3..0ec6764150 100644
--- a/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_pushers_for_deactivated_accounts.sql
@@ -14,8 +14,7 @@
*/
--- We may not have deleted all pushers for deactivated accounts. Do so now.
---
--- Note: We don't bother updating the `deleted_pushers` table as it's just use
--- to stop pushers on workers, and that will happen when they get next restarted.
-DELETE FROM pushers WHERE user_name IN (SELECT name FROM users WHERE deactivated = 1);
+-- We may not have deleted all pushers for deactivated accounts, so we set up a
+-- background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5908, 'remove_deactivated_pushers', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/59/09rejected_events_metadata.sql
index 9c95646281..cc9b267c7d 100644
--- a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
+++ b/synapse/storage/databases/main/schema/delta/59/09rejected_events_metadata.sql
@@ -13,5 +13,14 @@
* limitations under the License.
*/
+-- This originally was in 58/, but landed after 59/ was created, and so some
+-- servers running develop didn't run this delta. Running it again should be
+-- safe.
+--
+-- We first delete any in progress `rejected_events_metadata` background update,
+-- to ensure that we don't conflict when trying to insert the new one. (We could
+-- alternatively do an ON CONFLICT DO NOTHING, but that syntax isn't supported
+-- by older SQLite versions. Plus, this should be a rare case).
+DELETE FROM background_updates WHERE update_name = 'rejected_events_metadata';
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5828, 'rejected_events_metadata', '{}');
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 53f85195a7..32228f42ee 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
@@ -40,7 +40,6 @@ class ResponseCache(Generic[T]):
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
- self.pending_conditionals = {} # type: Dict[T, Set[Callable[[Any], bool]]]
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0
@@ -102,11 +101,7 @@ class ResponseCache(Generic[T]):
self.pending_result_cache[key] = result
def remove(r):
- should_cache = all(
- func(r) for func in self.pending_conditionals.pop(key, [])
- )
-
- if self.timeout_sec and should_cache:
+ if self.timeout_sec:
self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None
)
@@ -117,31 +112,6 @@ class ResponseCache(Generic[T]):
result.addBoth(remove)
return result.observe()
- def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
- self.pending_conditionals.setdefault(key, set()).add(conditional)
-
- def wrap_conditional(
- self,
- key: T,
- should_cache: Callable[[Any], bool],
- callback: "Callable[..., Any]",
- *args: Any,
- **kwargs: Any
- ) -> defer.Deferred:
- """The same as wrap(), but adds a conditional to the final execution.
-
- When the final execution completes, *all* conditionals need to return True for it to properly cache,
- else it'll not be cached in a timed fashion.
- """
-
- # See if there's already a result on this key that hasn't yet completed. Due to the single-threaded nature of
- # python, adding a key immediately in the same execution thread will not cause a race condition.
- result = self.get(key)
- if not result or isinstance(result, defer.Deferred) and not result.called:
- self.add_conditional(key, should_cache)
-
- return self.wrap(key, callback, *args, **kwargs)
-
def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred:
|