diff --git a/README.rst b/README.rst
index 31d375d19b..2441b6a35c 100644
--- a/README.rst
+++ b/README.rst
@@ -195,7 +195,7 @@ By default Synapse uses SQLite in and doing so trades performance for convenienc
SQLite is only recommended in Synapse for testing purposes or for servers with
light workloads.
-Almost all installations should opt to use PostreSQL. Advantages include:
+Almost all installations should opt to use PostgreSQL. Advantages include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
diff --git a/changelog.d/7675.removal b/changelog.d/7675.removal
new file mode 100644
index 0000000000..2500e2c578
--- /dev/null
+++ b/changelog.d/7675.removal
@@ -0,0 +1 @@
+Deprecate `m.login.jwt` login method in favour of `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec.
diff --git a/changelog.d/7717.bugfix b/changelog.d/7717.bugfix
new file mode 100644
index 0000000000..bcbf146fea
--- /dev/null
+++ b/changelog.d/7717.bugfix
@@ -0,0 +1 @@
+Fix the tables ignored by `synapse_port_db` to be in sync the current database schema.
diff --git a/changelog.d/7718.feature b/changelog.d/7718.feature
new file mode 100644
index 0000000000..17071b9ea9
--- /dev/null
+++ b/changelog.d/7718.feature
@@ -0,0 +1 @@
+Media can now be marked as safe from quarantined.
diff --git a/changelog.d/7724.doc b/changelog.d/7724.doc
new file mode 100644
index 0000000000..909e0345c7
--- /dev/null
+++ b/changelog.d/7724.doc
@@ -0,0 +1 @@
+Corrected misspelling of PostgreSQL.
diff --git a/changelog.d/7725.misc b/changelog.d/7725.misc
new file mode 100644
index 0000000000..f295a45521
--- /dev/null
+++ b/changelog.d/7725.misc
@@ -0,0 +1 @@
+Speed up state res v2 across large state differences.
diff --git a/changelog.d/7727.misc b/changelog.d/7727.misc
new file mode 100644
index 0000000000..4d12d10fda
--- /dev/null
+++ b/changelog.d/7727.misc
@@ -0,0 +1 @@
+Convert directory handler to async/await.
diff --git a/changelog.d/7730.bugfix b/changelog.d/7730.bugfix
new file mode 100644
index 0000000000..9da254b56c
--- /dev/null
+++ b/changelog.d/7730.bugfix
@@ -0,0 +1 @@
+Fix missing `Content-Length` on HTTP responses from the metrics handler.
diff --git a/changelog.d/7735.bugfix b/changelog.d/7735.bugfix
new file mode 100644
index 0000000000..86959a5ca4
--- /dev/null
+++ b/changelog.d/7735.bugfix
@@ -0,0 +1 @@
+Fix large state resolutions from stalling Synapse for seconds at a time.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 810e08beb5..2eb795192f 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
"account_validity": ["email_sent"],
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
+ "local_media_repository": ["safe_from_quarantine"],
}
@@ -128,10 +129,20 @@ APPEND_ONLY_TABLES = [
IGNORED_TABLES = {
+ # We don't port these tables, as they're a faff and we can regenerate
+ # them anyway.
"user_directory",
"user_directory_search",
- "users_who_share_rooms",
- "users_in_pubic_room",
+ "user_directory_search_content",
+ "user_directory_search_docsize",
+ "user_directory_search_segdir",
+ "user_directory_search_segments",
+ "user_directory_search_stat",
+ "user_directory_search_pos",
+ "users_who_share_private_rooms",
+ "users_in_public_room",
+ # UI auth sessions have foreign keys so additional care needs to be taken,
+ # the sessions are transient anyway, so ignore them.
"ui_auth_sessions",
"ui_auth_sessions_credentials",
}
@@ -300,8 +311,6 @@ class Porter(object):
return
if table in IGNORED_TABLES:
- # We don't port these tables, as they're a faff and we can regenerate
- # them anyway.
self.progress.update(table, table_size) # Mark table as done
return
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index f2f16b1e43..79a2df6201 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -17,8 +17,6 @@ import logging
import string
from typing import Iterable, List, Optional
-from twisted.internet import defer
-
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
AuthError,
@@ -55,8 +53,7 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
- @defer.inlineCallbacks
- def _create_association(
+ async def _create_association(
self,
room_alias: RoomAlias,
room_id: str,
@@ -76,13 +73,13 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users}
if not servers:
raise SynapseError(400, "Failed to get server list")
- yield self.store.create_room_alias_association(
+ await self.store.create_room_alias_association(
room_alias, room_id, servers, creator=creator
)
@@ -93,7 +90,7 @@ class DirectoryHandler(BaseHandler):
room_id: str,
servers: Optional[List[str]] = None,
check_membership: bool = True,
- ):
+ ) -> None:
"""Attempt to create a new alias
Args:
@@ -103,9 +100,6 @@ class DirectoryHandler(BaseHandler):
servers: Iterable of servers that others servers should try and join via
check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it).
-
- Returns:
- Deferred
"""
user_id = requester.user.to_string()
@@ -148,7 +142,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
- can_create = await self.can_modify_alias(room_alias, user_id=user_id)
+ can_create = self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@@ -158,7 +152,9 @@ class DirectoryHandler(BaseHandler):
await self._create_association(room_alias, room_id, servers, creator=user_id)
- async def delete_association(self, requester: Requester, room_alias: RoomAlias):
+ async def delete_association(
+ self, requester: Requester, room_alias: RoomAlias
+ ) -> str:
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@@ -169,7 +165,7 @@ class DirectoryHandler(BaseHandler):
room_alias
Returns:
- Deferred[unicode]: room id that the alias used to point to
+ room id that the alias used to point to
Raises:
NotFoundError: if the alias doesn't exist
@@ -191,7 +187,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
+ can_delete = self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@@ -208,8 +204,7 @@ class DirectoryHandler(BaseHandler):
return room_id
- @defer.inlineCallbacks
- def delete_appservice_association(
+ async def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias
):
if not service.is_interested_in_alias(room_alias.to_string()):
@@ -218,29 +213,27 @@ class DirectoryHandler(BaseHandler):
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE,
)
- yield self._delete_association(room_alias)
+ await self._delete_association(room_alias)
- @defer.inlineCallbacks
- def _delete_association(self, room_alias: RoomAlias):
+ async def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
- room_id = yield self.store.delete_room_alias(room_alias)
+ room_id = await self.store.delete_room_alias(room_alias)
return room_id
- @defer.inlineCallbacks
- def get_association(self, room_alias: RoomAlias):
+ async def get_association(self, room_alias: RoomAlias):
room_id = None
if self.hs.is_mine(room_alias):
- result = yield self.get_association_from_room_alias(room_alias)
+ result = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
servers = result.servers
else:
try:
- result = yield self.federation.make_query(
+ result = await self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={"room_alias": room_alias.to_string()},
@@ -265,7 +258,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
@@ -277,13 +270,12 @@ class DirectoryHandler(BaseHandler):
return {"room_id": room_id, "servers": servers}
- @defer.inlineCallbacks
- def on_directory_query(self, args):
+ async def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
- result = yield self.get_association_from_room_alias(room_alias)
+ result = await self.get_association_from_room_alias(room_alias)
if result is not None:
return {"room_id": result.room_id, "servers": result.servers}
@@ -344,16 +336,15 @@ class DirectoryHandler(BaseHandler):
ratelimit=False,
)
- @defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias: RoomAlias):
- result = yield self.store.get_association_from_room_alias(room_alias)
+ async def get_association_from_room_alias(self, room_alias: RoomAlias):
+ result = await self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
- result = yield as_handler.query_room_alias_exists(room_alias)
+ result = await as_handler.query_room_alias_exists(room_alias)
return result
- def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
+ def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool:
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
@@ -366,12 +357,12 @@ class DirectoryHandler(BaseHandler):
for service in interested_services:
if user_id == service.sender:
# this user IS the app service so they can do whatever they like
- return defer.succeed(True)
+ return True
elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias.
- return defer.succeed(False)
+ return False
# either no interested services, or no service with an exclusive lock
- return defer.succeed(True)
+ return True
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
@@ -459,8 +450,7 @@ class DirectoryHandler(BaseHandler):
await self.store.set_room_is_public(room_id, making_public)
- @defer.inlineCallbacks
- def edit_published_appservice_room_list(
+ async def edit_published_appservice_room_list(
self, appservice_id: str, network_id: str, room_id: str, visibility: str
):
"""Add or remove a room from the appservice/network specific public
@@ -475,7 +465,7 @@ class DirectoryHandler(BaseHandler):
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
- yield self.store.set_room_is_public_appservice(
+ await self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 873f6bc39f..3828ff0ef0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -376,6 +376,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version_id(room_id)
state_map = await resolve_events_with_store(
+ self.clock,
room_id,
room_version,
state_maps,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f400b56c4f..70b5cb0f89 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -881,7 +881,9 @@ class EventCreationHandler(object):
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
- mapping = yield directory_handler.get_association(room_alias)
+ mapping = yield defer.ensureDeferred(
+ directory_handler.get_association(room_alias)
+ )
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index ab7f948ed4..4304c60d56 100644
--- a/synapse/metrics/_exposition.py
+++ b/synapse/metrics/_exposition.py
@@ -208,6 +208,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
raise
self.send_response(200)
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
+ self.send_header("Content-Length", str(len(output)))
self.end_headers()
self.wfile.write(output)
@@ -261,4 +262,6 @@ class MetricsResource(Resource):
def render_GET(self, request):
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
- return generate_latest(self.registry)
+ response = generate_latest(self.registry)
+ request.setHeader(b"Content-Length", str(len(response)))
+ return response
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 3efc825d15..af7e365e90 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -133,6 +133,8 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _update_badge(self):
+ # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
+ # to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index c2c9a9c3aa..bf0f9bd077 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -81,7 +81,8 @@ class LoginRestServlet(RestServlet):
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
- JWT_TYPE = "m.login.jwt"
+ JWT_TYPE = "org.matrix.login.jwt"
+ JWT_TYPE_DEPRECATED = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
@@ -116,6 +117,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
+ flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
@@ -149,6 +151,7 @@ class LoginRestServlet(RestServlet):
try:
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
+ or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 50fd843f66..495d9f04c8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
+from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
@@ -414,6 +415,7 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
+ self.clock,
event.room_id,
room_version,
state_set_ids,
@@ -516,6 +518,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
+ self.clock,
room_id,
room_version,
list(state_groups_ids.values()),
@@ -589,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
+ clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
@@ -625,7 +629,7 @@ def resolve_events_with_store(
)
else:
return v2.resolve_events_with_store(
- room_id, room_version, state_sets, event_map, state_res_store
+ clock, room_id, room_version, state_sets, event_map, state_res_store
)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e25bc5d264..7181ecda9a 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -27,12 +27,20 @@ from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap
+from synapse.util import Clock
logger = logging.getLogger(__name__)
+# We want to yield to the reactor occasionally during state res when dealing
+# with large data sets, so that we don't exhaust the reactor. This is done by
+# yielding to reactor during loops every N iterations.
+_YIELD_AFTER_ITERATIONS = 100
+
+
@defer.inlineCallbacks
def resolve_events_with_store(
+ clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
@@ -42,13 +50,11 @@ def resolve_events_with_store(
"""Resolves the state using the v2 state resolution algorithm
Args:
+ clock
room_id: the room we are working in
-
room_version: The room version
-
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
-
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
@@ -113,7 +119,7 @@ def resolve_events_with_store(
)
sorted_power_events = yield _reverse_topological_power_sort(
- room_id, power_events, event_map, state_res_store, full_conflicted_set
+ clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
@@ -133,15 +139,16 @@ def resolve_events_with_store(
# OK, so we've now resolved the power events. Now sort the remaining
# events using the mainline of the resolved power level.
+ set_power_events = set(sorted_power_events)
leftover_events = [
- ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
+ ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
- room_id, leftover_events, pl, event_map, state_res_store
+ clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
@@ -316,12 +323,13 @@ def _add_event_and_auth_chain_to_graph(
@defer.inlineCallbacks
def _reverse_topological_power_sort(
- room_id, event_ids, event_map, state_res_store, auth_diff
+ clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
+ clock (Clock)
room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
@@ -333,18 +341,28 @@ def _reverse_topological_power_sort(
"""
graph = {}
- for event_id in event_ids:
+ for idx, event_id in enumerate(event_ids, start=1):
yield _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
event_to_pl = {}
- for event_id in graph:
+ for idx, event_id in enumerate(graph, start=1):
pl = yield _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
def _get_power_order(event_id):
ev = event_map[event_id]
pl = event_to_pl[event_id]
@@ -422,12 +440,13 @@ def _iterative_auth_checks(
@defer.inlineCallbacks
def _mainline_sort(
- room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+ clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
+ clock (Clock)
room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
@@ -437,8 +456,14 @@ def _mainline_sort(
Returns:
Deferred[list[str]]: The sorted list
"""
+ if not event_ids:
+ # It's possible for there to be no event IDs here to sort, so we can
+ # skip calculating the mainline in that case.
+ return []
+
mainline = []
pl = resolved_power_event_id
+ idx = 0
while pl:
mainline.append(pl)
pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
@@ -452,17 +477,29 @@ def _mainline_sort(
pl = aid
break
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
+ idx += 1
+
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
event_ids = list(event_ids)
order_map = {}
- for ev_id in event_ids:
+ for idx, ev_id in enumerate(event_ids, start=1):
depth = yield _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
+ # We yield occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _YIELD_AFTER_ITERATIONS == 0:
+ yield clock.sleep(0)
+
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 8aecd414c2..15bc13cbd0 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
+ def mark_local_media_as_safe(self, media_id: str):
+ """Mark a local media as safe from quarantining."""
+ return self.db.simple_update_one(
+ table="local_media_repository",
+ keyvalues={"media_id": media_id},
+ updatevalues={"safe_from_quarantine": True},
+ desc="mark_local_media_as_safe",
+ )
+
def get_url_cache(self, url, ts):
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 46f643c6b9..13e366536a 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -626,36 +626,10 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- total_media_quarantined = 0
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany(
- """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """,
- ((quarantined_by, media_id) for media_id in local_mxcs),
- )
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_mxcs
- ),
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
)
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
-
- return total_media_quarantined
-
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -805,17 +779,17 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
The total number of media items quarantined
"""
- total_media_quarantined = 0
-
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
- WHERE media_id = ?
+ WHERE media_id = ? AND safe_from_quarantine = ?
""",
- ((quarantined_by, media_id) for media_id in local_mxcs),
+ ((quarantined_by, media_id, False) for media_id in local_mxcs),
)
+ # Note that a rowcount of -1 can be used to indicate no rows were affected.
+ total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
txn.executemany(
"""
@@ -825,9 +799,7 @@ class RoomWorkerStore(SQLBaseStore):
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
-
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
+ total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
new file mode 100644
index 0000000000..597f2ffd3d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
new file mode 100644
index 0000000000..69db89ac0e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 977615ebef..b1a4decced 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
return hs
+ def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
+ """Ensure a piece of media is quarantined when trying to access it."""
+ request, channel = self.make_request(
+ "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+ )
+ request.render(self.download_resource)
+ self.pump(1.0)
+
+ # Should be quarantined
+ self.assertEqual(
+ 404,
+ int(channel.code),
+ msg=(
+ "Expected to receive a 404 on accessing quarantined media: %s"
+ % server_and_media_id
+ ),
+ )
+
def test_quarantine_media_requires_admin(self):
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
@@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Attempt to access the media
- request, channel = self.make_request(
- "GET",
- server_name_and_media_id,
- shorthand=False,
- access_token=admin_user_tok,
- )
- request.render(self.download_resource)
- self.pump(1.0)
-
- # Should be quarantined
- self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_name_and_media_id
- ),
- )
+ self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
def test_quarantine_all_media_in_room(self, override_url_template=None):
self.register_user("room_admin", "pass", admin=True)
@@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_and_media_id_2 = mxc_2[6:]
# Test that we cannot download any of the media anymore
- request, channel = self.make_request(
- "GET",
- server_and_media_id_1,
- shorthand=False,
- access_token=non_admin_user_tok,
- )
- request.render(self.download_resource)
- self.pump(1.0)
-
- # Should be quarantined
- self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_and_media_id_1
- ),
- )
-
- request, channel = self.make_request(
- "GET",
- server_and_media_id_2,
- shorthand=False,
- access_token=non_admin_user_tok,
- )
- request.render(self.download_resource)
- self.pump(1.0)
-
- # Should be quarantined
- self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_and_media_id_2
- ),
- )
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
- def test_quaraantine_all_media_in_room_deprecated_api_path(self):
+ def test_quarantine_all_media_in_room_deprecated_api_path(self):
# Perform the above test with the deprecated API path
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
@@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Attempt to access each piece of media
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+
+ def test_cannot_quarantine_safe_media(self):
+ self.register_user("user_admin", "pass", admin=True)
+ admin_user_tok = self.login("user_admin", "pass")
+
+ non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+ non_admin_user_tok = self.login("user_nonadmin", "pass")
+
+ # Upload some media
+ response_1 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+ response_2 = self.helper.upload_media(
+ self.upload_resource, self.image_data, tok=non_admin_user_tok
+ )
+
+ # Extract media IDs
+ server_and_media_id_1 = response_1["content_uri"][6:]
+ server_and_media_id_2 = response_2["content_uri"][6:]
+
+ # Mark the second item as safe from quarantine.
+ _, media_id_2 = server_and_media_id_2.split("/")
+ self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+
+ # Quarantine all media by this user
+ url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+ non_admin_user
+ )
request, channel = self.make_request(
- "GET",
- server_and_media_id_1,
- shorthand=False,
- access_token=non_admin_user_tok,
+ "POST", url.encode("ascii"), access_token=admin_user_tok,
)
- request.render(self.download_resource)
+ self.render(request)
self.pump(1.0)
-
- # Should be quarantined
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
- 404,
- int(channel.code),
- msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
- % server_and_media_id_1,
- ),
+ json.loads(channel.result["body"].decode("utf-8")),
+ {"num_quarantined": 1},
+ "Expected 1 quarantined item",
)
+ # Attempt to access each piece of media, the first should fail, the
+ # second should succeed.
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+
# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
@@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request.render(self.download_resource)
self.pump(1.0)
- # Should be quarantined
+ # Shouldn't be quarantined
self.assertEqual(
- 404,
+ 200,
int(channel.code),
msg=(
- "Expected to receive a 404 on accessing quarantined media: %s"
+ "Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 9033f09fd2..fd97999956 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -526,7 +526,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "HS256").decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -568,7 +570,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["error"], "Invalid JWT")
def test_login_no_token(self):
- params = json.dumps({"type": "m.login.jwt"})
+ params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
self.assertEqual(channel.result["code"], b"401", channel.result)
@@ -640,7 +642,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "RS256").decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index cdc347bc53..38f9b423ef 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -17,6 +17,8 @@ import itertools
import attr
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
@@ -41,6 +43,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
ORIGIN_SERVER_TS = 0
+class FakeClock:
+ def sleep(self, msec):
+ return defer.succeed(None)
+
+
class FakeEvent(object):
"""A fake event we use as a convenience.
@@ -417,6 +424,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
+ FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
@@ -565,6 +573,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
+ FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],
|