diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 770b921aff..5898f0788c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,24 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distutils.util import strtobool
-
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
- self.enable_registration = bool(
- strtobool(str(config.get("enable_registration", False)))
+ self.enable_registration = strtobool(
+ str(config.get("enable_registration", False))
)
if "disable_registration" in config:
- self.enable_registration = not bool(
- strtobool(str(config["disable_registration"]))
+ self.enable_registration = not strtobool(
+ str(config["disable_registration"])
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8028663fa8..3ec4120f85 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -17,7 +17,6 @@
import abc
import os
-from distutils.util import strtobool
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
# for the sake of tests, it is set here while it cannot be configured on the
# homeserver object itself.
+
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 531232cf5a..4e5ef106a0 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -764,14 +764,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
+ # If the deferred was called, bail early.
+ if self.deferred.called:
+ return
+
self.stream.write(data)
self.length += len(data)
+ # The first time the maximum size is exceeded, error and cancel the
+ # connection. dataReceived might be called again if data was received
+ # in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
- self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason: Failure) -> None:
+ # If the maximum size was already exceeded, there's nothing to do.
+ if self.deferred.called:
+ return
+
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id: str):
+ async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, user_id: str):
+ async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, server_name: str, media_id: str):
+ async def on_POST(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
+class ProtectMediaByID(RestServlet):
+ """Protect local media from being quarantined.
+ """
+
+ PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Protecting local media by ID: %s", media_id)
+
+ # Quarantine this media id
+ await self.store.mark_local_media_as_safe(media_id)
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_DELETE(self, request, server_name: str, media_id: str):
+ async def on_DELETE(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request, server_name: str):
+ async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
+ ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..b103c8694c 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
if len(items) <= maxitems:
return str(items)
return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False
+
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+ are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
+ 'val' is anything else.
+
+ This is lifted from distutils.util.strtobool, with the exception that it actually
+ returns a bool, rather than an int.
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ elif val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ else:
+ raise ValueError("invalid truth value %r" % (val,))
|