diff --git a/synapse/__init__.py b/synapse/__init__.py
index da52463531..919293cd80 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.40.0rc2"
+__version__ = "1.40.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a986fdb47a..e0e24fddac 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -62,7 +62,7 @@ class JoinRules:
INVITE = "invite"
PRIVATE = "private"
# As defined for MSC3083.
- MSC3083_RESTRICTED = "restricted"
+ RESTRICTED = "restricted"
class RestrictedJoinRuleTypes:
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index bc678efe49..11280c4462 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -76,6 +76,8 @@ class RoomVersion:
# MSC2716: Adds m.room.power_levels -> content.historical field to control
# whether "insertion", "chunk", "marker" events can be sent
msc2716_historical = attr.ib(type=bool)
+ # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
+ msc2716_redactions = attr.ib(type=bool)
class RoomVersions:
@@ -92,6 +94,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
V2 = RoomVersion(
"2",
@@ -106,6 +109,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
V3 = RoomVersion(
"3",
@@ -120,6 +124,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
V4 = RoomVersion(
"4",
@@ -134,6 +139,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
V5 = RoomVersion(
"5",
@@ -148,6 +154,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
V6 = RoomVersion(
"6",
@@ -162,6 +169,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -176,10 +184,11 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=False,
msc2716_historical=False,
+ msc2716_redactions=False,
)
- MSC3083 = RoomVersion(
- "org.matrix.msc3083.v2",
- RoomDisposition.UNSTABLE,
+ V7 = RoomVersion(
+ "7",
+ RoomDisposition.STABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
@@ -187,12 +196,13 @@ class RoomVersions:
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
- msc3083_join_rules=True,
- msc2403_knocking=False,
+ msc3083_join_rules=False,
+ msc2403_knocking=True,
msc2716_historical=False,
+ msc2716_redactions=False,
)
- V7 = RoomVersion(
- "7",
+ V8 = RoomVersion(
+ "8",
RoomDisposition.STABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
@@ -201,13 +211,29 @@ class RoomVersions:
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
- msc3083_join_rules=False,
+ msc3083_join_rules=True,
msc2403_knocking=True,
msc2716_historical=False,
+ msc2716_redactions=False,
)
MSC2716 = RoomVersion(
"org.matrix.msc2716",
- RoomDisposition.STABLE,
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=False,
+ msc2403_knocking=True,
+ msc2716_historical=True,
+ msc2716_redactions=False,
+ )
+ MSC2716v2 = RoomVersion(
+ "org.matrix.msc2716v2",
+ RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
@@ -218,6 +244,7 @@ class RoomVersions:
msc3083_join_rules=False,
msc2403_knocking=True,
msc2716_historical=True,
+ msc2716_redactions=True,
)
@@ -231,9 +258,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V5,
RoomVersions.V6,
RoomVersions.MSC2176,
- RoomVersions.MSC3083,
RoomVersions.V7,
RoomVersions.MSC2716,
+ RoomVersions.V8,
)
}
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4c92e9a2d4..c3a0c10499 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -370,10 +370,7 @@ def _is_membership_change_allowed(
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
- elif (
- room_version.msc3083_join_rules
- and join_rule == JoinRules.MSC3083_RESTRICTED
- ):
+ elif room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED:
# This is the same as public, but the event must contain a reference
# to the server who authorised the join. If the event does not contain
# the proper content it is rejected.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0c07f62f4..b6da2f60af 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -17,7 +17,7 @@ from typing import Any, Mapping, Union
from frozendict import frozendict
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.util.async_helpers import yieldable_gather_results
@@ -135,6 +135,12 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
add_fields("history_visibility")
elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
add_fields("redacts")
+ elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
+ add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID)
+ elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK:
+ add_fields(EventContentFields.MSC2716_CHUNK_ID)
+ elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
+ add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 53fac1f8a3..4288ffff09 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -213,7 +213,7 @@ class EventAuthHandler:
raise AuthError(
403,
- "You do not belong to any of the required rooms to join this room.",
+ "You do not belong to any of the required rooms/spaces to join this room.",
)
async def has_restricted_join_rules(
@@ -240,7 +240,7 @@ class EventAuthHandler:
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
- return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED
+ return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 016c5df2ca..7ca14e1d84 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1184,8 +1184,7 @@ class PresenceHandler(BasePresenceHandler):
new_fields = {"state": presence}
if not ignore_status_msg:
- msg = status_msg if presence != PresenceState.OFFLINE else None
- new_fields["status_msg"] = msg
+ new_fields["status_msg"] = status_msg
if presence == PresenceState.ONLINE or (
presence == PresenceState.BUSY and self._busy_presence_enabled
@@ -1478,7 +1477,7 @@ def format_user_presence_state(
content["user_id"] = state.user_id
if state.last_active_ts:
content["last_active_ago"] = now - state.last_active_ts
- if state.status_msg and state.state != PresenceState.OFFLINE:
+ if state.status_msg:
content["status_msg"] = state.status_msg
if state.state == PresenceState.ONLINE:
content["currently_active"] = state.currently_active
@@ -1840,9 +1839,7 @@ def handle_timeout(
# don't set them as offline.
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
- state = state.copy_and_replace(
- state=PresenceState.OFFLINE, status_msg=None
- )
+ state = state.copy_and_replace(state=PresenceState.OFFLINE)
changed = True
else:
# We expect to be poked occasionally by the other side.
@@ -1850,7 +1847,7 @@ def handle_timeout(
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
# The other side seems to have disappeared.
- state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None)
+ state = state.copy_and_replace(state=PresenceState.OFFLINE)
changed = True
return state if changed else None
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 2517f278b6..d0060f9046 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -18,7 +18,7 @@ import re
from collections import deque
from typing import (
TYPE_CHECKING,
- Collection,
+ Deque,
Dict,
Iterable,
List,
@@ -38,9 +38,12 @@ from synapse.api.constants import (
Membership,
RoomTypes,
)
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -57,16 +60,73 @@ MAX_ROOMS_PER_SPACE = 50
MAX_SERVERS_PER_SPACE = 3
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PaginationKey:
+ """The key used to find unique pagination session."""
+
+ # The first three entries match the request parameters (and cannot change
+ # during a pagination session).
+ room_id: str
+ suggested_only: bool
+ max_depth: Optional[int]
+ # The randomly generated token.
+ token: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PaginationSession:
+ """The information that is stored for pagination."""
+
+ # The time the pagination session was created, in milliseconds.
+ creation_time_ms: int
+ # The queue of rooms which are still to process.
+ room_queue: Deque["_RoomQueueEntry"]
+ # A set of rooms which have been processed.
+ processed_rooms: Set[str]
+
+
class SpaceSummaryHandler:
+ # The time a pagination session remains valid for.
+ _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
+
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
- self._auth = hs.get_auth()
self._event_auth_handler = hs.get_event_auth_handler()
self._store = hs.get_datastore()
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
+ # A map of query information to the current pagination state.
+ #
+ # TODO Allow for multiple workers to share this data.
+ # TODO Expire pagination tokens.
+ self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
+
+ # If a user tries to fetch the same page multiple times in quick succession,
+ # only process the first attempt and return its result to subsequent requests.
+ self._pagination_response_cache: ResponseCache[
+ Tuple[str, bool, Optional[int], Optional[int], Optional[str]]
+ ] = ResponseCache(
+ hs.get_clock(),
+ "get_room_hierarchy",
+ )
+
+ def _expire_pagination_sessions(self):
+ """Expire pagination session which are old."""
+ expire_before = (
+ self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
+ )
+ to_expire = []
+
+ for key, value in self._pagination_sessions.items():
+ if value.creation_time_ms < expire_before:
+ to_expire.append(key)
+
+ for key in to_expire:
+ logger.debug("Expiring pagination session id %s", key)
+ del self._pagination_sessions[key]
+
async def get_space_summary(
self,
requester: str,
@@ -92,9 +152,13 @@ class SpaceSummaryHandler:
Returns:
summary dict to return
"""
- # first of all, check that the user is in the room in question (or it's
- # world-readable)
- await self._auth.check_user_in_room_or_world_readable(room_id, requester)
+ # First of all, check that the room is accessible.
+ if not await self._is_local_room_accessible(room_id, requester):
+ raise AuthError(
+ 403,
+ "User %s not in room %s, and room previews are disabled"
+ % (requester, room_id),
+ )
# the queue of rooms to process
room_queue = deque((_RoomQueueEntry(room_id, ()),))
@@ -130,7 +194,7 @@ class SpaceSummaryHandler:
requester, None, room_id, suggested_only, max_children
)
- events: Collection[JsonDict] = []
+ events: Sequence[JsonDict] = []
if room_entry:
rooms_result.append(room_entry.room)
events = room_entry.children
@@ -158,48 +222,10 @@ class SpaceSummaryHandler:
room = room_entry.room
fed_room_id = room_entry.room_id
- # The room should only be included in the summary if:
- # a. the user is in the room;
- # b. the room is world readable; or
- # c. the user could join the room, e.g. the join rules
- # are set to public or the user is in a space that
- # has been granted access to the room.
- #
- # Note that we know the user is not in the root room (which is
- # why the remote call was made in the first place), but the user
- # could be in one of the children rooms and we just didn't know
- # about the link.
-
- # The API doesn't return the room version so assume that a
- # join rule of knock is valid.
- include_room = (
- room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
- or room.get("world_readable") is True
- )
-
- # Check if the user is a member of any of the allowed spaces
- # from the response.
- allowed_rooms = room.get("allowed_room_ids") or room.get(
- "allowed_spaces"
- )
- if (
- not include_room
- and allowed_rooms
- and isinstance(allowed_rooms, list)
- ):
- include_room = await self._event_auth_handler.is_user_in_rooms(
- allowed_rooms, requester
- )
-
- # Finally, if this isn't the requested room, check ourselves
- # if we can access the room.
- if not include_room and fed_room_id != queue_entry.room_id:
- include_room = await self._is_room_accessible(
- fed_room_id, requester, None
- )
-
# The user can see the room, include it!
- if include_room:
+ if await self._is_remote_room_accessible(
+ requester, fed_room_id, room
+ ):
# Before returning to the client, remove the allowed_room_ids
# and allowed_spaces keys.
room.pop("allowed_room_ids", None)
@@ -245,6 +271,158 @@ class SpaceSummaryHandler:
return {"rooms": rooms_result, "events": events_result}
+ async def get_room_hierarchy(
+ self,
+ requester: str,
+ requested_room_id: str,
+ suggested_only: bool = False,
+ max_depth: Optional[int] = None,
+ limit: Optional[int] = None,
+ from_token: Optional[str] = None,
+ ) -> JsonDict:
+ """
+ Implementation of the room hierarchy C-S API.
+
+ Args:
+ requester: The user ID of the user making this request.
+ requested_room_id: The room ID to start the hierarchy at (the "root" room).
+ suggested_only: Whether we should only return children with the "suggested"
+ flag set.
+ max_depth: The maximum depth in the tree to explore, must be a
+ non-negative integer.
+
+ 0 would correspond to just the root room, 1 would include just
+ the root room's children, etc.
+ limit: An optional limit on the number of rooms to return per
+ page. Must be a positive integer.
+ from_token: An optional pagination token.
+
+ Returns:
+ The JSON hierarchy dictionary.
+ """
+ # If a user tries to fetch the same page multiple times in quick succession,
+ # only process the first attempt and return its result to subsequent requests.
+ #
+ # This is due to the pagination process mutating internal state, attempting
+ # to process multiple requests for the same page will result in errors.
+ return await self._pagination_response_cache.wrap(
+ (requested_room_id, suggested_only, max_depth, limit, from_token),
+ self._get_room_hierarchy,
+ requester,
+ requested_room_id,
+ suggested_only,
+ max_depth,
+ limit,
+ from_token,
+ )
+
+ async def _get_room_hierarchy(
+ self,
+ requester: str,
+ requested_room_id: str,
+ suggested_only: bool = False,
+ max_depth: Optional[int] = None,
+ limit: Optional[int] = None,
+ from_token: Optional[str] = None,
+ ) -> JsonDict:
+ """See docstring for SpaceSummaryHandler.get_room_hierarchy."""
+
+ # First of all, check that the room is accessible.
+ if not await self._is_local_room_accessible(requested_room_id, requester):
+ raise AuthError(
+ 403,
+ "User %s not in room %s, and room previews are disabled"
+ % (requester, requested_room_id),
+ )
+
+ # If this is continuing a previous session, pull the persisted data.
+ if from_token:
+ self._expire_pagination_sessions()
+
+ pagination_key = _PaginationKey(
+ requested_room_id, suggested_only, max_depth, from_token
+ )
+ if pagination_key not in self._pagination_sessions:
+ raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
+
+ # Load the previous state.
+ pagination_session = self._pagination_sessions[pagination_key]
+ room_queue = pagination_session.room_queue
+ processed_rooms = pagination_session.processed_rooms
+ else:
+ # the queue of rooms to process
+ room_queue = deque((_RoomQueueEntry(requested_room_id, ()),))
+
+ # Rooms we have already processed.
+ processed_rooms = set()
+
+ rooms_result: List[JsonDict] = []
+
+ # Cap the limit to a server-side maximum.
+ if limit is None:
+ limit = MAX_ROOMS
+ else:
+ limit = min(limit, MAX_ROOMS)
+
+ # Iterate through the queue until we reach the limit or run out of
+ # rooms to include.
+ while room_queue and len(rooms_result) < limit:
+ queue_entry = room_queue.popleft()
+ room_id = queue_entry.room_id
+ current_depth = queue_entry.depth
+ if room_id in processed_rooms:
+ # already done this room
+ continue
+
+ logger.debug("Processing room %s", room_id)
+
+ is_in_room = await self._store.is_host_joined(room_id, self._server_name)
+ if is_in_room:
+ room_entry = await self._summarize_local_room(
+ requester,
+ None,
+ room_id,
+ suggested_only,
+ # TODO Handle max children.
+ max_children=None,
+ )
+
+ if room_entry:
+ rooms_result.append(room_entry.as_json())
+
+ # Add the child to the queue. We have already validated
+ # that the vias are a list of server names.
+ #
+ # If the current depth is the maximum depth, do not queue
+ # more entries.
+ if max_depth is None or current_depth < max_depth:
+ room_queue.extendleft(
+ _RoomQueueEntry(
+ ev["state_key"], ev["content"]["via"], current_depth + 1
+ )
+ for ev in reversed(room_entry.children)
+ )
+
+ processed_rooms.add(room_id)
+ else:
+ # TODO Federation.
+ pass
+
+ result: JsonDict = {"rooms": rooms_result}
+
+ # If there's additional data, generate a pagination token (and persist state).
+ if room_queue:
+ next_batch = random_string(24)
+ result["next_batch"] = next_batch
+ pagination_key = _PaginationKey(
+ requested_room_id, suggested_only, max_depth, next_batch
+ )
+ self._pagination_sessions[pagination_key] = _PaginationSession(
+ self._clock.time_msec(), room_queue, processed_rooms
+ )
+
+ return result
+
async def federation_space_summary(
self,
origin: str,
@@ -336,7 +514,7 @@ class SpaceSummaryHandler:
Returns:
A room entry if the room should be returned. None, otherwise.
"""
- if not await self._is_room_accessible(room_id, requester, origin):
+ if not await self._is_local_room_accessible(room_id, requester, origin):
return None
room_entry = await self._build_room_entry(room_id, for_federation=bool(origin))
@@ -438,8 +616,8 @@ class SpaceSummaryHandler:
return results
- async def _is_room_accessible(
- self, room_id: str, requester: Optional[str], origin: Optional[str]
+ async def _is_local_room_accessible(
+ self, room_id: str, requester: Optional[str], origin: Optional[str] = None
) -> bool:
"""
Calculate whether the room should be shown in the spaces summary.
@@ -550,6 +728,51 @@ class SpaceSummaryHandler:
)
return False
+ async def _is_remote_room_accessible(
+ self, requester: str, room_id: str, room: JsonDict
+ ) -> bool:
+ """
+ Calculate whether the room received over federation should be shown in the spaces summary.
+
+ It should be included if:
+
+ * The requester is joined or can join the room (per MSC3173).
+ * The history visibility is set to world readable.
+
+ Note that the local server is not in the requested room (which is why the
+ remote call was made in the first place), but the user could have access
+ due to an invite, etc.
+
+ Args:
+ requester: The user requesting the summary.
+ room_id: The room ID returned over federation.
+ room: The summary of the child room returned over federation.
+
+ Returns:
+ True if the room should be included in the spaces summary.
+ """
+ # The API doesn't return the room version so assume that a
+ # join rule of knock is valid.
+ if (
+ room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+ or room.get("world_readable") is True
+ ):
+ return True
+
+ # Check if the user is a member of any of the allowed spaces
+ # from the response.
+ allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces")
+ if allowed_rooms and isinstance(allowed_rooms, list):
+ if await self._event_auth_handler.is_user_in_rooms(
+ allowed_rooms, requester
+ ):
+ return True
+
+ # Finally, check locally if we can access the room. The user might
+ # already be in the room (if it was a child room), or there might be a
+ # pending invite, etc.
+ return await self._is_local_room_accessible(room_id, requester)
+
async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict:
"""
Generate en entry suitable for the 'rooms' list in the summary response.
@@ -565,7 +788,7 @@ class SpaceSummaryHandler:
stats = await self._store.get_room_with_stats(room_id)
# currently this should be impossible because we call
- # check_user_in_room_or_world_readable on the room before we get here, so
+ # _is_local_room_accessible on the room before we get here, so
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
@@ -645,6 +868,7 @@ class SpaceSummaryHandler:
class _RoomQueueEntry:
room_id: str
via: Sequence[str]
+ depth: int = 0
@attr.s(frozen=True, slots=True, auto_attribs=True)
@@ -655,7 +879,12 @@ class _RoomEntry:
# An iterable of the sorted, stripped children events for children of this room.
#
# This may not include all children.
- children: Collection[JsonDict] = ()
+ children: Sequence[JsonDict] = ()
+
+ def as_json(self) -> JsonDict:
+ result = dict(self.room)
+ result["children_state"] = self.children
+ return result
def _has_valid_via(e: EventBase) -> bool:
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 17e1c5abb1..c577142268 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
+from typing import Optional
+import attr
from zope.interface import implementer
from twisted.internet import defer, protocol
@@ -21,7 +24,6 @@ from twisted.internet.error import ConnectError
from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.web import http
-from twisted.web.http_headers import Headers
logger = logging.getLogger(__name__)
@@ -30,6 +32,22 @@ class ProxyConnectError(ConnectError):
pass
+@attr.s
+class ProxyCredentials:
+ username_password = attr.ib(type=bytes)
+
+ def as_proxy_authorization_value(self) -> bytes:
+ """
+ Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+ Returns:
+ A transformation of the authentication string the encoded value for
+ a Proxy-Authorization header.
+ """
+ # Encode as base64 and prepend the authorization type
+ return b"Basic " + base64.encodebytes(self.username_password)
+
+
@implementer(IStreamClientEndpoint)
class HTTPConnectProxyEndpoint:
"""An Endpoint implementation which will send a CONNECT request to an http proxy
@@ -46,7 +64,7 @@ class HTTPConnectProxyEndpoint:
proxy_endpoint: the endpoint to use to connect to the proxy
host: hostname that we want to CONNECT to
port: port that we want to connect to
- headers: Extra HTTP headers to include in the CONNECT request
+ proxy_creds: credentials to authenticate at proxy
"""
def __init__(
@@ -55,20 +73,20 @@ class HTTPConnectProxyEndpoint:
proxy_endpoint: IStreamClientEndpoint,
host: bytes,
port: int,
- headers: Headers,
+ proxy_creds: Optional[ProxyCredentials],
):
self._reactor = reactor
self._proxy_endpoint = proxy_endpoint
self._host = host
self._port = port
- self._headers = headers
+ self._proxy_creds = proxy_creds
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
def connect(self, protocolFactory: ClientFactory):
f = HTTPProxiedClientFactory(
- self._host, self._port, protocolFactory, self._headers
+ self._host, self._port, protocolFactory, self._proxy_creds
)
d = self._proxy_endpoint.connect(f)
# once the tcp socket connects successfully, we need to wait for the
@@ -87,7 +105,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
dst_host: hostname that we want to CONNECT to
dst_port: port that we want to connect to
wrapped_factory: The original Factory
- headers: Extra HTTP headers to include in the CONNECT request
+ proxy_creds: credentials to authenticate at proxy
"""
def __init__(
@@ -95,12 +113,12 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
dst_host: bytes,
dst_port: int,
wrapped_factory: ClientFactory,
- headers: Headers,
+ proxy_creds: Optional[ProxyCredentials],
):
self.dst_host = dst_host
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
- self.headers = headers
+ self.proxy_creds = proxy_creds
self.on_connection = defer.Deferred()
def startedConnecting(self, connector):
@@ -114,7 +132,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
self.dst_port,
wrapped_protocol,
self.on_connection,
- self.headers,
+ self.proxy_creds,
)
def clientConnectionFailed(self, connector, reason):
@@ -145,7 +163,7 @@ class HTTPConnectProtocol(protocol.Protocol):
connected_deferred: a Deferred which will be callbacked with
wrapped_protocol when the CONNECT completes
- headers: Extra HTTP headers to include in the CONNECT request
+ proxy_creds: credentials to authenticate at proxy
"""
def __init__(
@@ -154,16 +172,16 @@ class HTTPConnectProtocol(protocol.Protocol):
port: int,
wrapped_protocol: Protocol,
connected_deferred: defer.Deferred,
- headers: Headers,
+ proxy_creds: Optional[ProxyCredentials],
):
self.host = host
self.port = port
self.wrapped_protocol = wrapped_protocol
self.connected_deferred = connected_deferred
- self.headers = headers
+ self.proxy_creds = proxy_creds
self.http_setup_client = HTTPConnectSetupClient(
- self.host, self.port, self.headers
+ self.host, self.port, self.proxy_creds
)
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
@@ -205,30 +223,38 @@ class HTTPConnectSetupClient(http.HTTPClient):
Args:
host: The hostname to send in the CONNECT message
port: The port to send in the CONNECT message
- headers: Extra headers to send with the CONNECT message
+ proxy_creds: credentials to authenticate at proxy
"""
- def __init__(self, host: bytes, port: int, headers: Headers):
+ def __init__(
+ self,
+ host: bytes,
+ port: int,
+ proxy_creds: Optional[ProxyCredentials],
+ ):
self.host = host
self.port = port
- self.headers = headers
+ self.proxy_creds = proxy_creds
self.on_connected = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
- # Send any additional specified headers
- for name, values in self.headers.getAllRawHeaders():
- for value in values:
- self.sendHeader(name, value)
+ # Determine whether we need to set Proxy-Authorization headers
+ if self.proxy_creds:
+ # Set a Proxy-Authorization header
+ self.sendHeader(
+ b"Proxy-Authorization",
+ self.proxy_creds.as_proxy_authorization_value(),
+ )
self.endHeaders()
def handleStatus(self, version: bytes, status: bytes, message: bytes):
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
- raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+ raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
def handleEndHeaders(self):
logger.debug("End Headers")
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index c16b7f10e6..1238bfd287 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,6 +14,10 @@
import logging
import urllib.parse
from typing import Any, Generator, List, Optional
+from urllib.request import ( # type: ignore[attr-defined]
+ getproxies_environment,
+ proxy_bypass_environment,
+)
from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
@@ -30,9 +34,12 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse
from synapse.crypto.context_factory import FederationPolicyForHTTPS
-from synapse.http.client import BlacklistingAgentWrapper
+from synapse.http import proxyagent
+from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
from synapse.util import Clock
@@ -57,6 +64,14 @@ class MatrixFederationAgent:
user_agent:
The user agent header to use for federation requests.
+ ip_whitelist: Allowed IP addresses.
+
+ ip_blacklist: Disallowed IP addresses.
+
+ proxy_reactor: twisted reactor to use for connections to the proxy server
+ reactor might have some blacklisting applied (i.e. for DNS queries),
+ but we need unblocked access to the proxy.
+
_srv_resolver:
SrvResolver implementation to use for looking up SRV records. None
to use a default implementation.
@@ -71,11 +86,18 @@ class MatrixFederationAgent:
reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
+ ip_whitelist: IPSet,
ip_blacklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
- self._reactor = reactor
+ # proxy_reactor is not blacklisted
+ proxy_reactor = reactor
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist)
+
self._clock = Clock(reactor)
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
@@ -83,24 +105,27 @@ class MatrixFederationAgent:
self._pool.cachedConnectionTimeout = 2 * 60
self._agent = Agent.usingEndpointFactory(
- self._reactor,
+ reactor,
MatrixHostnameEndpointFactory(
- reactor, tls_client_options_factory, _srv_resolver
+ reactor,
+ proxy_reactor,
+ tls_client_options_factory,
+ _srv_resolver,
),
pool=self._pool,
)
self.user_agent = user_agent
if _well_known_resolver is None:
- # Note that the name resolver has already been wrapped in a
- # IPBlacklistingResolver by MatrixFederationHttpClient.
_well_known_resolver = WellKnownResolver(
- self._reactor,
+ reactor,
agent=BlacklistingAgentWrapper(
- Agent(
- self._reactor,
+ ProxyAgent(
+ reactor,
+ proxy_reactor,
pool=self._pool,
contextFactory=tls_client_options_factory,
+ use_proxy=True,
),
ip_blacklist=ip_blacklist,
),
@@ -200,10 +225,12 @@ class MatrixHostnameEndpointFactory:
def __init__(
self,
reactor: IReactorCore,
+ proxy_reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: Optional[SrvResolver],
):
self._reactor = reactor
+ self._proxy_reactor = proxy_reactor
self._tls_client_options_factory = tls_client_options_factory
if srv_resolver is None:
@@ -211,9 +238,10 @@ class MatrixHostnameEndpointFactory:
self._srv_resolver = srv_resolver
- def endpointForURI(self, parsed_uri):
+ def endpointForURI(self, parsed_uri: URI):
return MatrixHostnameEndpoint(
self._reactor,
+ self._proxy_reactor,
self._tls_client_options_factory,
self._srv_resolver,
parsed_uri,
@@ -227,23 +255,45 @@ class MatrixHostnameEndpoint:
Args:
reactor: twisted reactor to use for underlying requests
+ proxy_reactor: twisted reactor to use for connections to the proxy server.
+ 'reactor' might have some blacklisting applied (i.e. for DNS queries),
+ but we need unblocked access to the proxy.
tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
srv_resolver: The SRV resolver to use
parsed_uri: The parsed URI that we're wanting to connect to.
+
+ Raises:
+ ValueError if the environment variables contain an invalid proxy specification.
+ RuntimeError if no tls_options_factory is given for a https connection
"""
def __init__(
self,
reactor: IReactorCore,
+ proxy_reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
srv_resolver: SrvResolver,
parsed_uri: URI,
):
self._reactor = reactor
-
self._parsed_uri = parsed_uri
+ # http_proxy is not needed because federation is always over TLS
+ proxies = getproxies_environment()
+ https_proxy = proxies["https"].encode() if "https" in proxies else None
+ self.no_proxy = proxies["no"] if "no" in proxies else None
+
+ # endpoint and credentials to use to connect to the outbound https proxy, if any.
+ (
+ self._https_proxy_endpoint,
+ self._https_proxy_creds,
+ ) = proxyagent.http_proxy_endpoint(
+ https_proxy,
+ proxy_reactor,
+ tls_client_options_factory,
+ )
+
# set up the TLS connection params
#
# XXX disabling TLS is really only supported here for the benefit of the
@@ -273,9 +323,33 @@ class MatrixHostnameEndpoint:
host = server.host
port = server.port
+ should_skip_proxy = False
+ if self.no_proxy is not None:
+ should_skip_proxy = proxy_bypass_environment(
+ host.decode(),
+ proxies={"no": self.no_proxy},
+ )
+
+ endpoint: IStreamClientEndpoint
try:
- logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
- endpoint = HostnameEndpoint(self._reactor, host, port)
+ if self._https_proxy_endpoint and not should_skip_proxy:
+ logger.debug(
+ "Connecting to %s:%i via %s",
+ host.decode("ascii"),
+ port,
+ self._https_proxy_endpoint,
+ )
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self._https_proxy_endpoint,
+ host,
+ port,
+ proxy_creds=self._https_proxy_creds,
+ )
+ else:
+ logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
+ # not using a proxy
+ endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
result = await make_deferred_yieldable(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2efa15bf04..2e9898997c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,6 @@ from synapse.api.errors import (
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
- BlacklistingReactorWrapper,
BodyExceededMaxSize,
ByteWriteable,
encode_query_args,
@@ -69,7 +68,7 @@ from synapse.http.federation.matrix_federation_agent import MatrixFederationAgen
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
-from synapse.types import ISynapseReactor, JsonDict
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -325,13 +324,7 @@ class MatrixFederationHttpClient:
self.signing_key = hs.signing_key
self.server_name = hs.hostname
- # We need to use a DNS resolver which filters out blacklisted IP
- # addresses, to prevent DNS rebinding.
- self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
- hs.get_reactor(),
- hs.config.federation_ip_range_whitelist,
- hs.config.federation_ip_range_blacklist,
- )
+ self.reactor = hs.get_reactor()
user_agent = hs.version_string
if hs.config.user_agent_suffix:
@@ -342,6 +335,7 @@ class MatrixFederationHttpClient:
self.reactor,
tls_client_options_factory,
user_agent,
+ hs.config.federation_ip_range_whitelist,
hs.config.federation_ip_range_blacklist,
)
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 19e987f118..a3f31452d0 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import base64
import logging
import re
from typing import Any, Dict, Optional, Tuple
@@ -21,7 +20,6 @@ from urllib.request import ( # type: ignore[attr-defined]
proxy_bypass_environment,
)
-import attr
from zope.interface import implementer
from twisted.internet import defer
@@ -38,7 +36,7 @@ from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
-from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@@ -46,22 +44,6 @@ logger = logging.getLogger(__name__)
_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
-@attr.s
-class ProxyCredentials:
- username_password = attr.ib(type=bytes)
-
- def as_proxy_authorization_value(self) -> bytes:
- """
- Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
-
- Returns:
- A transformation of the authentication string the encoded value for
- a Proxy-Authorization header.
- """
- # Encode as base64 and prepend the authorization type
- return b"Basic " + base64.encodebytes(self.username_password)
-
-
@implementer(IAgent)
class ProxyAgent(_AgentBase):
"""An Agent implementation which will use an HTTP proxy if one was requested
@@ -95,6 +77,7 @@ class ProxyAgent(_AgentBase):
Raises:
ValueError if use_proxy is set and the environment variables
contain an invalid proxy specification.
+ RuntimeError if no tls_options_factory is given for a https connection
"""
def __init__(
@@ -131,11 +114,11 @@ class ProxyAgent(_AgentBase):
https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None
- self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint(
+ self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint(
http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
)
- self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint(
+ self.https_proxy_endpoint, self.https_proxy_creds = http_proxy_endpoint(
https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
)
@@ -224,22 +207,12 @@ class ProxyAgent(_AgentBase):
and self.https_proxy_endpoint
and not should_skip_proxy
):
- connect_headers = Headers()
-
- # Determine whether we need to set Proxy-Authorization headers
- if self.https_proxy_creds:
- # Set a Proxy-Authorization header
- connect_headers.addRawHeader(
- b"Proxy-Authorization",
- self.https_proxy_creds.as_proxy_authorization_value(),
- )
-
endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor,
self.https_proxy_endpoint,
parsed_uri.host,
parsed_uri.port,
- headers=connect_headers,
+ self.https_proxy_creds,
)
else:
# not using a proxy
@@ -268,10 +241,10 @@ class ProxyAgent(_AgentBase):
)
-def _http_proxy_endpoint(
+def http_proxy_endpoint(
proxy: Optional[bytes],
reactor: IReactorCore,
- tls_options_factory: IPolicyForHTTPS,
+ tls_options_factory: Optional[IPolicyForHTTPS],
**kwargs,
) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy
@@ -294,6 +267,7 @@ def _http_proxy_endpoint(
Raise:
ValueError if proxy has no hostname or unsupported scheme.
+ RuntimeError if no tls_options_factory is given for a https connection
"""
if proxy is None:
return None, None
@@ -305,8 +279,13 @@ def _http_proxy_endpoint(
proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
if scheme == b"https":
- tls_options = tls_options_factory.creatorForNetloc(host, port)
- proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
+ if tls_options_factory:
+ tls_options = tls_options_factory.creatorForNetloc(host, port)
+ proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
+ else:
+ raise RuntimeError(
+ f"No TLS options for a https connection via proxy {proxy!s}"
+ )
return proxy_endpoint, credentials
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0a19a333d7..5f0555039d 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -259,7 +259,9 @@ class DeleteMediaByID(RestServlet):
logging.info("Deleting local media by ID: %s", media_id)
- deleted_media, total = await self.media_repository.delete_local_media(media_id)
+ deleted_media, total = await self.media_repository.delete_local_media_ids(
+ [media_id]
+ )
return 200, {"deleted_media": deleted_media, "total": total}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index eef76ab18a..41f21ba118 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -172,7 +172,7 @@ class UserRestServletV2(RestServlet):
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(400, "Can only look up local users")
ret = await self.admin_handler.get_user(target_user)
@@ -796,7 +796,7 @@ class PushersRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(400, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
@@ -811,10 +811,10 @@ class PushersRestServlet(RestServlet):
class UserMediaRestServlet(RestServlet):
"""
Gets information about all uploaded local media for a specific `user_id`.
+ With DELETE request you can delete all this media.
Example:
- http://localhost:8008/_synapse/admin/v1/users/
- @user:server/media
+ http://localhost:8008/_synapse/admin/v1/users/@user:server/media
Args:
The parameters `from` and `limit` are required for pagination.
@@ -830,6 +830,7 @@ class UserMediaRestServlet(RestServlet):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.media_repository = hs.get_media_repository()
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -840,7 +841,7 @@ class UserMediaRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(400, "Can only look up local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
@@ -898,6 +899,73 @@ class UserMediaRestServlet(RestServlet):
return 200, ret
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only look up local users")
+
+ user = await self.store.get_user_by_id(user_id)
+ if user is None:
+ raise NotFoundError("Unknown user")
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # If neither `order_by` nor `dir` is set, set the default order
+ # to newest media is on top for backward compatibility.
+ if b"order_by" not in request.args and b"dir" not in request.args:
+ order_by = MediaSortOrder.CREATED_TS.value
+ direction = "b"
+ else:
+ order_by = parse_string(
+ request,
+ "order_by",
+ default=MediaSortOrder.CREATED_TS.value,
+ allowed_values=(
+ MediaSortOrder.MEDIA_ID.value,
+ MediaSortOrder.UPLOAD_NAME.value,
+ MediaSortOrder.CREATED_TS.value,
+ MediaSortOrder.LAST_ACCESS_TS.value,
+ MediaSortOrder.MEDIA_LENGTH.value,
+ MediaSortOrder.MEDIA_TYPE.value,
+ MediaSortOrder.QUARANTINED_BY.value,
+ MediaSortOrder.SAFE_FROM_QUARANTINE.value,
+ ),
+ )
+ direction = parse_string(
+ request, "dir", default="f", allowed_values=("f", "b")
+ )
+
+ media, _ = await self.store.get_local_media_by_user_paginate(
+ start, limit, user_id, order_by, direction
+ )
+
+ deleted_media, total = await self.media_repository.delete_local_media_ids(
+ ([row["media_id"] for row in media])
+ )
+
+ return 200, {"deleted_media": deleted_media, "total": total}
+
class UserTokenRestServlet(RestServlet):
"""An admin API for logging in as a user.
@@ -1017,7 +1085,7 @@ class RateLimitRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):
- raise SynapseError(400, "Can only lookup local users")
+ raise SynapseError(400, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f887970b76..f1bc43be2d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -437,6 +437,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
+ state_events_at_start = []
for state_event in body["state_events_at_start"]:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -502,6 +503,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
)
event_id = event.event_id
+ state_events_at_start.append(event_id)
auth_event_ids.append(event_id)
events_to_create = body["events"]
@@ -651,7 +653,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
event_ids.append(base_insertion_event.event_id)
return 200, {
- "state_events": auth_event_ids,
+ "state_events": state_events_at_start,
"events": event_ids,
"next_chunk_id": insertion_event["content"][
EventContentFields.MSC2716_NEXT_CHUNK_ID
@@ -1445,6 +1447,46 @@ class RoomSpaceSummaryRestServlet(RestServlet):
)
+class RoomHierarchyRestServlet(RestServlet):
+ PATTERNS = (
+ re.compile(
+ "^/_matrix/client/unstable/org.matrix.msc2946"
+ "/rooms/(?P<room_id>[^/]*)/hierarchy$"
+ ),
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._auth = hs.get_auth()
+ self._space_summary_handler = hs.get_space_summary_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request, allow_guest=True)
+
+ max_depth = parse_integer(request, "max_depth")
+ if max_depth is not None and max_depth < 0:
+ raise SynapseError(
+ 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON
+ )
+
+ limit = parse_integer(request, "limit")
+ if limit is not None and limit <= 0:
+ raise SynapseError(
+ 400, "'limit' must be a positive integer", Codes.BAD_JSON
+ )
+
+ return 200, await self._space_summary_handler.get_room_hierarchy(
+ requester.user.to_string(),
+ room_id,
+ suggested_only=parse_boolean(request, "suggested_only", default=False),
+ max_depth=max_depth,
+ limit=limit,
+ from_token=parse_string(request, "from"),
+ )
+
+
def register_servlets(hs: "HomeServer", http_server, is_worker=False):
msc2716_enabled = hs.config.experimental.msc2716_enabled
@@ -1463,6 +1505,7 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomTypingRestServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
RoomSpaceSummaryRestServlet(hs).register(http_server)
+ RoomHierarchyRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
RoomAliasListServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 4f702f890c..0f5ce41ff8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -836,7 +836,9 @@ class MediaRepository:
return {"deleted": deleted}
- async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]:
+ async def delete_local_media_ids(
+ self, media_ids: List[str]
+ ) -> Tuple[List[str], int]:
"""
Delete the given local or remote media ID from this server
@@ -845,7 +847,7 @@ class MediaRepository:
Returns:
A tuple of (list of deleted media IDs, total deleted media IDs).
"""
- return await self._remove_local_media_from_disk([media_id])
+ return await self._remove_local_media_from_disk(media_ids)
async def delete_old_local_media(
self,
|