diff --git a/changelog.d/10776.feature b/changelog.d/10776.feature
new file mode 100644
index 0000000000..aec0685a3d
--- /dev/null
+++ b/changelog.d/10776.feature
@@ -0,0 +1 @@
+Only allow the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send?chunk_id=xxx` endpoint to connect to an already existing insertion event.
diff --git a/changelog.d/10777.misc b/changelog.d/10777.misc
new file mode 100644
index 0000000000..aed78a16f5
--- /dev/null
+++ b/changelog.d/10777.misc
@@ -0,0 +1 @@
+Split out [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) meta events to their own fields in the `/batch_send` response.
diff --git a/changelog.d/10785.misc b/changelog.d/10785.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10785.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10810.bugfix b/changelog.d/10810.bugfix
new file mode 100644
index 0000000000..43e91f1f51
--- /dev/null
+++ b/changelog.d/10810.bugfix
@@ -0,0 +1 @@
+Fix a case where logging contexts would go missing when federation requests time out.
diff --git a/changelog.d/10812.misc b/changelog.d/10812.misc
new file mode 100644
index 0000000000..586a0b3a96
--- /dev/null
+++ b/changelog.d/10812.misc
@@ -0,0 +1 @@
+Use direct references to config flags.
diff --git a/changelog.d/10815.misc b/changelog.d/10815.misc
new file mode 100644
index 0000000000..fc2534dc14
--- /dev/null
+++ b/changelog.d/10815.misc
@@ -0,0 +1 @@
+Specify the type of token in generic "Invalid token" error messages.
\ No newline at end of file
diff --git a/changelog.d/10816.misc b/changelog.d/10816.misc
new file mode 100644
index 0000000000..2ca55b334a
--- /dev/null
+++ b/changelog.d/10816.misc
@@ -0,0 +1 @@
+Make `StateFilter` frozen so it is hashable.
diff --git a/changelog.d/10817.misc b/changelog.d/10817.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10817.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10823.misc b/changelog.d/10823.misc
new file mode 100644
index 0000000000..0532969900
--- /dev/null
+++ b/changelog.d/10823.misc
@@ -0,0 +1 @@
+Add type hints to the state database.
diff --git a/changelog.d/10834.misc b/changelog.d/10834.misc
new file mode 100644
index 0000000000..037695e6e9
--- /dev/null
+++ b/changelog.d/10834.misc
@@ -0,0 +1 @@
+Factor out PNG image data to a constant to be used in several tests.
diff --git a/mypy.ini b/mypy.ini
index 09ffdda1b9..b21e1555ab 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -60,6 +60,7 @@ files =
synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
+ synapse/storage/databases/state,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/keys.py,
@@ -86,10 +87,11 @@ files =
tests/handlers/test_sync.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
+ tests/storage/test_state.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
-[mypy-synapse.rest.client.*]
+[mypy-synapse.rest.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 05699714ee..e6ca9232ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -70,8 +70,8 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
- self._track_appservice_user_ips = hs.config.track_appservice_user_ips
- self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+ self._macaroon_secret_key = hs.config.key.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
async def check_user_in_room(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index e6bced93d5..a3b95f4de0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -30,13 +30,15 @@ class AuthBlocking:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- self._server_notices_mxid = hs.config.server_notices_mxid
- self._hs_disabled = hs.config.hs_disabled
- self._hs_disabled_message = hs.config.hs_disabled_message
- self._admin_contact = hs.config.admin_contact
- self._max_mau_value = hs.config.max_mau_value
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
+ self._hs_disabled = hs.config.server.hs_disabled
+ self._hs_disabled_message = hs.config.server.hs_disabled_message
+ self._admin_contact = hs.config.server.admin_contact
+ self._max_mau_value = hs.config.server.max_mau_value
+ self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+ self._mau_limits_reserved_threepids = (
+ hs.config.server.mau_limits_reserved_threepids
+ )
self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index c644b4dfc5..d310976fe3 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -102,7 +102,7 @@ class FederationPolicyForHTTPS:
self._config = config
# Check if we're using a custom list of a CA certificates
- trust_root = config.federation_ca_trust_root
+ trust_root = config.tls.federation_ca_trust_root
if trust_root is None:
# Use CA root certs provided by OpenSSL
trust_root = platformTrust()
@@ -113,7 +113,7 @@ class FederationPolicyForHTTPS:
# moving to TLS 1.2 by default, we want to respect the config option if
# it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
# let us do).
- minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
+ minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version]
_verify_ssl = CertificateOptions(
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
@@ -125,10 +125,10 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- self._should_verify = self._config.federation_verify_certificates
+ self._should_verify = self._config.tls.federation_verify_certificates
self._federation_certificate_verification_whitelist = (
- self._config.federation_certificate_verification_whitelist
+ self._config.tls.federation_certificate_verification_whitelist
)
def get_options(self, host: bytes):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 9e9b1c1c86..e1e13a2412 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -572,7 +572,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- self.key_servers = self.config.key_servers
+ self.key_servers = self.config.key.key_servers
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 214ee948fa..638959cbec 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1237,7 +1237,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
- if not self.config.use_presence and edu_type == EduTypes.Presence:
+ if not self.config.server.use_presence and edu_type == EduTypes.Presence:
return
# Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4671ac0242..720d7bd74d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -594,7 +594,7 @@ class FederationSender(AbstractFederationSender):
destinations (list[str])
"""
- if not states or not self.hs.config.use_presence:
+ if not states or not self.hs.config.server.use_presence:
# No-op if presence is disabled.
return
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbbf6fd834..3ea6270083 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1347,7 +1347,7 @@ class AuthHandler(BaseHandler):
try:
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
- raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+ raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(res.user_id)
return res
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 4e8f7f1d85..0b24b40eb9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -413,7 +413,7 @@ class InitialSyncHandler(BaseHandler):
async def get_presence():
# If presence is disabled, return an empty list
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return []
states = await presence_handler.get_states(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 39b39cd3e2..4ab962a84b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -374,7 +374,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
- self._presence_enabled = hs.config.use_presence
+ self._presence_enabled = hs.config.server.use_presence
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
@@ -584,7 +584,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id = target_user.to_string()
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@@ -601,7 +601,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@@ -618,7 +618,7 @@ class PresenceHandler(BasePresenceHandler):
self.server_name = hs.hostname
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
- self._presence_enabled = hs.config.use_presence
+ self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@@ -916,7 +916,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
user_id = user.to_string()
@@ -949,7 +949,7 @@ class PresenceHandler(BasePresenceHandler):
"""
# Override if it should affect the user's presence, if presence is
# disabled.
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
affect_presence = False
if affect_presence:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index edfdb99cbd..7523d8e839 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1090,7 +1090,7 @@ class SyncHandler:
block_all_presence_data = (
since_token is None and sync_config.filter_collection.blocks_all_presence()
)
- if self.hs_config.use_presence and not block_all_presence_data:
+ if self.hs_config.server.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
sync_result_builder,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index c2ea51ee16..5204c3d08c 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -321,8 +321,11 @@ class SimpleHttpClient:
self.user_agent = hs.version_string
self.clock = hs.get_clock()
- if hs.config.user_agent_suffix:
- self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+ if hs.config.server.user_agent_suffix:
+ self.user_agent = "%s %s" % (
+ self.user_agent,
+ hs.config.server.user_agent_suffix,
+ )
# We use this for our body producers to ensure that they use the correct
# reactor.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2e9898997c..ef10ec0937 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -66,7 +66,7 @@ from synapse.http.client import (
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging import opentracing
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
@@ -553,20 +553,29 @@ class MatrixFederationHttpClient:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
- request_deferred = self.agent.request(
+
+ # To preserve the logging context, the timeout is treated
+ # in a similar way to `defer.gatherResults`:
+ # * Each logging context-preserving fork is wrapped in
+ # `run_in_background`. In this case there is only one,
+ # since the timeout fork is not logging-context aware.
+ # * The `Deferred` that joins the forks back together is
+ # wrapped in `make_deferred_yieldable` to restore the
+ # logging context regardless of the path taken.
+ request_deferred = run_in_background(
+ self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
-
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
- response = await request_deferred
+ response = await make_deferred_yieldable(request_deferred)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 36aabd8422..065948f982 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -365,7 +365,7 @@ class HttpPusher(Pusher):
if event.type == "m.room.member" and event.is_state():
d["notification"]["membership"] = event.content["membership"]
d["notification"]["user_is_target"] = event.state_key == self.user_id
- if self.hs.config.push_include_content and event.content:
+ if self.hs.config.push.push_include_content and event.content:
d["notification"]["content"] = event.content
# We no longer send aliases separately, instead, we send the human
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b89c6e6f2b..e38e3c5d44 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name
- self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
+ self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
logger.info("Created Mailer for app_name %s" % app_name)
@@ -796,8 +796,8 @@ class Mailer:
Returns:
A link to open a room in the web client.
"""
- if self.hs.config.email_riot_base_url:
- base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
+ if self.hs.config.email.email_riot_base_url:
+ base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
@@ -815,9 +815,9 @@ class Mailer:
Returns:
A link to open the notification in the web client.
"""
- if self.hs.config.email_riot_base_url:
+ if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % (
- self.hs.config.email_riot_base_url,
+ self.hs.config.email.email_riot_base_url,
notif["room_id"],
notif["event_id"],
)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 021275437c..29ed346d37 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -35,12 +35,12 @@ class PusherFactory:
"http": HttpPusher
}
- logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
- if hs.config.email_enable_notifs:
+ logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs)
+ if hs.config.email.email_enable_notifs:
self.mailers: Dict[str, Mailer] = {}
- self._notif_template_html = hs.config.email_notif_template_html
- self._notif_template_text = hs.config.email_notif_template_text
+ self._notif_template_html = hs.config.email.email_notif_template_html
+ self._notif_template_text = hs.config.email.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index a1436f3930..26735447a6 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool:
self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
- self._pusher_shard_config = hs.config.push.pusher_shard_config
+ self._pusher_shard_config = hs.config.worker.pusher_shard_config
self._instance_name = hs.get_instance_name()
self._should_start_pushers = (
self._instance_name in self._pusher_shard_config.instances
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3adc576124..e04af705eb 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import JsonResource
+from typing import TYPE_CHECKING
+
+from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
@@ -57,6 +59,9 @@ from synapse.rest.client import (
voip,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
- def register_servlets(client_resource, hs):
+ def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)
# Deprecated in r0
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 5715190a78..a6fa03c90f 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore()
async def on_GET(
- self, request: SynapseRequest, user_id, device_id: str
+ self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index f5a38c2670..19f84f33f2 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, json_resource: HttpServer):
+ def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index c1a1ba645e..681e491826 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {}
self.hs = hs
- def _clear_old_nonces(self):
+ def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index ed96978448..d466edeec2 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,6 +14,7 @@
import logging
import re
+from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request
@@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if not requester.app_service:
raise AuthError(
- 403,
+ HTTPStatus.FORBIDDEN,
"Only application services can use the /batchsend endpoint",
)
@@ -192,7 +193,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if prev_events_from_query is None:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"prev_event query parameter is required when inserting historical messages back in time",
errcode=Codes.MISSING_PARAM,
)
@@ -213,7 +214,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
- state_events_at_start = []
+ state_event_ids_at_start = []
for state_event in body["state_events_at_start"]:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -279,7 +280,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
event_id = event.event_id
- state_events_at_start.append(event_id)
+ state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
events_to_create = body["events"]
@@ -299,7 +300,18 @@ class RoomBatchSendEventRestServlet(RestServlet):
# event, which causes the HS to ask for the state at the start of
# the chunk later.
prev_event_ids = [fake_prev_event_id]
- # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+
+ # Verify the chunk_id_from_query corresponds to an actual insertion event
+ # and have the chunk connected.
+ corresponding_insertion_event_id = (
+ await self.store.get_insertion_event_by_chunk_id(chunk_id_from_query)
+ )
+ if corresponding_insertion_event_id is None:
+ raise SynapseError(
+ 400,
+ "No insertion event corresponds to the given ?chunk_id",
+ errcode=Codes.INVALID_PARAM,
+ )
pass
# Otherwise, create an insertion event to act as a starting point.
#
@@ -424,20 +436,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
context=context,
)
- # Add the base_insertion_event to the bottom of the list we return
- if base_insertion_event is not None:
- event_ids.append(base_insertion_event.event_id)
+ insertion_event_id = event_ids[0]
+ chunk_event_id = event_ids[-1]
+ historical_event_ids = event_ids[1:-1]
- return 200, {
- "state_events": state_events_at_start,
- "events": event_ids,
+ response_dict = {
+ "state_event_ids": state_event_ids_at_start,
+ "event_ids": historical_event_ids,
"next_chunk_id": insertion_event["content"][
EventContentFields.MSC2716_NEXT_CHUNK_ID
],
+ "insertion_event_id": insertion_event_id,
+ "chunk_event_id": chunk_event_id,
}
+ if base_insertion_event is not None:
+ response_dict["base_insertion_event_id"] = base_insertion_event.event_id
+
+ return HTTPStatus.OK, response_dict
def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
- return 501, "Not implemented"
+ return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
def on_PUT(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 11f7320832..06e0fbde22 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -17,17 +17,22 @@ import logging
from hashlib import sha256
from http import HTTPStatus
from os import path
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
import jinja2
from jinja2 import TemplateNotFound
+from twisted.web.server import Request
+
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"
@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): homeserver
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- async def _async_render_GET(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- async def _async_render_POST(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
@@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("success.html not found")
- def _render_template(self, request, template_name, **template_args):
+ def _render_template(
+ self, request: Request, template_name: str, **template_args: Any
+ ) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
@@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)
- def _check_hash(self, userid, userhmac):
+ def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
- userid (unicode):
- userhmac (bytes):
+ userid:
+ userhmac:
Raises:
SynapseError if the hash doesn't match
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
index 4487b54abf..78df7af2cf 100644
--- a/synapse/rest/health.py
+++ b/synapse/rest/health.py
@@ -13,6 +13,7 @@
# limitations under the License.
from twisted.web.resource import Resource
+from twisted.web.server import Request
class HealthResource(Resource):
@@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index c6c63073ea..7f8c1de1ff 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class KeyApiV2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 25f6eb842f..ebe243bcfd 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -58,18 +63,18 @@ class LocalKey(Resource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
- def update_response_body(self, time_now_msec):
+ def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
- def response_json_object(self):
+ def response_json_object(self) -> JsonDict:
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
@@ -94,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 744360e5fd..d8fd7938a4 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
+from twisted.web.server import Request
+
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
@@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
+ assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
@@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(
+ self,
+ request: Request,
+ query: JsonDict,
+ query_remote_on_cache_miss: bool = False,
+ ) -> None:
logger.info("Handling query for keys %r", query)
store_queries = []
@@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), results in cached.items():
- results = [(result["ts_added_ms"], result) for result in results]
+ for (server_name, key_id, _), key_results in cached.items():
+ results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
@@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
- results = {"server_keys": signed_keys}
+ response = {"server_keys": signed_keys}
- respond_with_json(request, 200, results, canonical_json=True)
+ respond_with_json(request, 200, response, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 90364ebcf7..7c881f2bdb 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,7 +16,10 @@
import logging
import os
import urllib
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
@@ -120,7 +123,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any.
"""
- def _quote(x):
+ def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types.
@@ -280,51 +283,74 @@ class Responder:
"""
pass
- def __enter__(self):
+ def __enter__(self) -> None:
pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
pass
-class FileInfo:
- """Details about a requested/uploaded file.
-
- Attributes:
- server_name (str): The server name where the media originated from,
- or None if local.
- file_id (str): The local ID of the file. For local files this is the
- same as the media_id
- url_cache (bool): If the file is for the url preview cache
- thumbnail (bool): Whether the file is a thumbnail or not.
- thumbnail_width (int)
- thumbnail_height (int)
- thumbnail_method (str)
- thumbnail_type (str): Content type of thumbnail, e.g. image/png
- thumbnail_length (int): The size of the media file, in bytes.
- """
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+ """Details about a generated thumbnail."""
- def __init__(
- self,
- server_name,
- file_id,
- url_cache=False,
- thumbnail=False,
- thumbnail_width=None,
- thumbnail_height=None,
- thumbnail_method=None,
- thumbnail_type=None,
- thumbnail_length=None,
- ):
- self.server_name = server_name
- self.file_id = file_id
- self.url_cache = url_cache
- self.thumbnail = thumbnail
- self.thumbnail_width = thumbnail_width
- self.thumbnail_height = thumbnail_height
- self.thumbnail_method = thumbnail_method
- self.thumbnail_type = thumbnail_type
- self.thumbnail_length = thumbnail_length
+ width: int
+ height: int
+ method: str
+ # Content type of thumbnail, e.g. image/png
+ type: str
+ # The size of the media file, in bytes.
+ length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+ """Details about a requested/uploaded file."""
+
+ # The server name where the media originated from, or None if local.
+ server_name: Optional[str]
+ # The local ID of the file. For local files this is the same as the media_id
+ file_id: str
+ # If the file is for the url preview cache
+ url_cache: bool = False
+ # Whether the file is a thumbnail or not.
+ thumbnail: Optional[ThumbnailInfo] = None
+
+ # The below properties exist to maintain compatibility with third-party modules.
+ @property
+ def thumbnail_width(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.width
+
+ @property
+ def thumbnail_height(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.height
+
+ @property
+ def thumbnail_method(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.method
+
+ @property
+ def thumbnail_type(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.type
+
+ @property
+ def thumbnail_length(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 09531ebf54..39bbe4e874 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,7 @@
import functools
import os
import re
-from typing import Callable, List
+from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
"""
@functools.wraps(func)
- def _wrapped(self, *args, **kwargs):
+ def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
@@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
- ):
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0f5ce41ff8..50e4c9e29f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -32,6 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
+from synapse.config.repository import ThumbnailRequirement
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@@ -42,6 +44,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
+ ThumbnailInfo,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -113,7 +116,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
- def _start_update_recently_accessed(self):
+ def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
)
@@ -210,7 +213,7 @@ class MediaRepository:
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
- file_info = FileInfo(None, media_id, url_cache=url_cache)
+ file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
@@ -468,7 +471,9 @@ class MediaRepository:
return media_info
- def _get_thumbnail_requirements(self, media_type):
+ def _get_thumbnail_requirements(
+ self, media_type: str
+ ) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";")
if scpos > 0:
media_type = media_type[:scpos]
@@ -514,7 +519,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
- url_cache: Optional[str],
+ url_cache: bool,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
@@ -548,11 +553,12 @@ class MediaRepository:
server_name=None,
file_id=media_id,
url_cache=url_cache,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -585,7 +591,7 @@ class MediaRepository:
t_type: str,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
- FileInfo(server_name, file_id, url_cache=False)
+ FileInfo(server_name, file_id)
)
try:
@@ -616,11 +622,12 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -742,12 +749,13 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
url_cache=url_cache,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 56cdc1b4ed..01fada8fb5 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -15,7 +15,20 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+from types import TracebackType
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Generator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+)
import attr
@@ -83,12 +96,14 @@ class MediaStorage:
return fname
- async def write_to_file(self, source: IO, output: IO):
+ async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
- def store_into_file(self, file_info: FileInfo):
+ def store_into_file(
+ self, file_info: FileInfo
+ ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -125,7 +140,7 @@ class MediaStorage:
try:
with open(fname, "wb") as f:
- async def finish():
+ async def finish() -> None:
# Ensure that all writes have been flushed and close the
# file.
f.flush()
@@ -176,9 +191,9 @@ class MediaStorage:
self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
)
@@ -220,9 +235,9 @@ class MediaStorage:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
@@ -255,10 +270,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
@@ -267,10 +282,10 @@ class MediaStorage:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id
@@ -279,10 +294,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.local_media_filepath_rel(file_info.file_id)
@@ -315,7 +330,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer)
)
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
self.open_file.close()
@@ -339,7 +359,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock)
path = attr.ib(type=str)
- async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f108da05db..fe0627d9b0 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,6 +27,7 @@ from urllib import parse as urlparse
import attr
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
@@ -473,7 +474,7 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag,
)
- def _start_expire_url_cache_data(self):
+ def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
)
@@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
def _iterate_over_text(
- tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+ tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 0ff6ad3c0c..6c9969e55f 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
- async def store():
+ async def store() -> None:
try:
return await maybe_awaitable(
self.backend.store_file(path, file_info)
@@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
- def __str__(self):
+ def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 12bd745cb2..22f43d8531 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -26,6 +26,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
+ ThumbnailInfo,
parse_media_id,
respond_404,
respond_with_file,
@@ -114,7 +115,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_id,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
server_name=None,
)
@@ -149,11 +150,12 @@ class ThumbnailResource(DirectServeJsonResource):
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -173,7 +175,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height,
desired_method,
desired_type,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
)
if file_path:
@@ -210,11 +212,12 @@ class ThumbnailResource(DirectServeJsonResource):
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -271,7 +274,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_info["filesystem_id"],
- url_cache=None,
+ url_cache=False,
server_name=server_name,
)
@@ -285,7 +288,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos: List[Dict[str, Any]],
media_id: str,
file_id: str,
- url_cache: Optional[str] = None,
+ url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
@@ -299,7 +302,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos:
@@ -318,13 +321,16 @@ class ThumbnailResource(DirectServeJsonResource):
respond_404(request)
return
+ # The thumbnail property must exist.
+ assert file_info.thumbnail is not None
+
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
return
@@ -351,18 +357,18 @@ class ThumbnailResource(DirectServeJsonResource):
server_name,
file_id=file_id,
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
@@ -370,8 +376,8 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
else:
logger.info("Failed to find any generated thumbnails")
@@ -385,7 +391,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
- url_cache: Optional[str],
+ url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
@@ -398,7 +404,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
@@ -495,12 +501,13 @@ class ThumbnailResource(DirectServeJsonResource):
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- thumbnail_length=thumbnail_info["thumbnail_length"],
+ thumbnail=ThumbnailInfo(
+ width=thumbnail_info["thumbnail_width"],
+ height=thumbnail_info["thumbnail_height"],
+ type=thumbnail_info["thumbnail_type"],
+ method=thumbnail_info["thumbnail_method"],
+ length=thumbnail_info["thumbnail_length"],
+ ),
)
# No matching thumbnail was found.
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a65e9e1802..df54a40649 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
- def set_limits(max_image_pixels: int):
+ def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index 67c1ed1f5f..1c1c7b3613 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request
@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: Request):
+ async def _async_render_POST(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 36ba401656..81fec39659 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -13,16 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class OIDCResource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index 7785f17e90..4f375cb74c 100644
--- a/synapse/rest/synapse/client/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_oidc_callback(request)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# the auth response can be returned via an x-www-form-urlencoded form instead
# of GET params, as per
# https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d30b478b98..28ae083497 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Generator, List, Tuple
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -27,6 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.templates import build_jinja_env
if TYPE_CHECKING:
@@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request: Request):
+ async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
localpart = parse_string(request, "username", required=True)
session_id = get_username_mapping_session_cookie_from_request(request)
@@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: SynapseRequest):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None
diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 781ccb237c..3f247e6a2c 100644
--- a/synapse/rest/synapse/client/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -13,17 +13,21 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class SAML2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index b37c7083dc..64378ed57b 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
import saml2.metadata
from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SAML2MetadataResource(Resource):
@@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string(
configfile=None, config=self.sp_config
)
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index 774ccd870f..47d2a6a229 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -15,7 +15,10 @@
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
@@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 6a66a88c53..c80a3a99aa 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,26 +13,26 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import set_cors_headers
+from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class WellKnownBuilder:
- """Utility to construct the well-known response
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._config = hs.config
- def get_well_known(self):
+ def get_well_known(self) -> Optional[JsonDict]:
# if we don't have a public_baseurl, we can't help much here.
if self._config.server.public_baseurl is None:
return None
@@ -52,11 +52,11 @@ class WellKnownResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs)
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
set_cors_headers(request)
r = self._well_known_builder.get_well_known()
if not r:
diff --git a/synapse/server.py b/synapse/server.py
index 4777ef585d..637eb15b78 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -392,7 +392,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
return InsecureInterceptableContextFactory()
return RegularPolicyForHTTPS()
@@ -418,8 +418,8 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
return SimpleHttpClient(
self,
- ip_whitelist=self.config.ip_range_whitelist,
- ip_blacklist=self.config.ip_range_blacklist,
+ ip_whitelist=self.config.server.ip_range_whitelist,
+ ip_blacklist=self.config.server.ip_range_blacklist,
use_proxy=True,
)
@@ -801,18 +801,18 @@ class HomeServer(metaclass=abc.ABCMeta):
logger.info(
"Connecting to redis (host=%r port=%r) for external cache",
- self.config.redis_host,
- self.config.redis_port,
+ self.config.redis.redis_host,
+ self.config.redis.redis_port,
)
return lazyConnection(
hs=self,
- host=self.config.redis_host,
- port=self.config.redis_port,
+ host=self.config.redis.redis_host,
+ port=self.config.redis.redis_port,
password=self.config.redis.redis_password,
reconnect=True,
)
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
- return self.config.send_federation
+ return self.config.worker.send_federation
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1dc347f0c9..5c21402dea 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -61,6 +61,7 @@ from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
+from .room_batch import RoomBatchStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .session import SessionStore
@@ -81,6 +82,7 @@ class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
+ RoomBatchStore,
RegistrationStore,
StreamStore,
ProfileStore,
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
new file mode 100644
index 0000000000..54fa361d3e
--- /dev/null
+++ b/synapse/storage/databases/main/room_batch.py
@@ -0,0 +1,36 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from synapse.storage._base import SQLBaseStore
+
+
+class RoomBatchStore(SQLBaseStore):
+ async def get_insertion_event_by_chunk_id(self, chunk_id: str) -> Optional[str]:
+ """Retrieve a insertion event ID.
+
+ Args:
+ chunk_id: The chunk ID of the insertion event to retrieve.
+
+ Returns:
+ The event_id of an insertion event, or None if there is no known
+ insertion event for the given insertion event.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="insertion_events",
+ keyvalues={"next_chunk_id": chunk_id},
+ retcol="event_id",
+ allow_none=True,
+ )
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb07f..eb1118d2cb 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""
- def _count_state_group_hops_txn(self, txn, state_group):
+ def _count_state_group_hops_txn(
+ self, txn: LoggingTransaction, state_group: int
+ ) -> int:
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
+ next_group: Optional[int] = state_group
count = 0
while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter: Optional[StateFilter] = None
- ):
+ self,
+ txn: LoggingTransaction,
+ groups: List[int],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Mapping[int, StateMap[str]]:
state_filter = state_filter or StateFilter.all()
- results = {group: {} for group in groups}
+ results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""
for group in groups:
- args = [group]
+ args: List[Union[int, str]] = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- next_group = group
+ next_group: Optional[int] = group
while next_group:
# We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
allow_none=True,
)
+ # The results shouldn't be considered mutable.
return results
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
- async def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(
+ self, progress: dict, batch_size: int
+ ) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
max_group = rows[0][0]
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- (prev_group,) = txn.fetchone()
+ # There will be a result due to the coalesce.
+ (prev_group,) = txn.fetchone() # type: ignore
new_last_state_group = state_group
if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
# otherwise read performance degrades.
continue
- prev_state = self._get_state_groups_from_groups_txn(
+ prev_state_by_group = self._get_state_groups_from_groups_txn(
txn, [prev_group]
)
- prev_state = prev_state[prev_group]
+ prev_state = prev_state_by_group[prev_group]
- curr_state = self._get_state_groups_from_groups_txn(
+ curr_state_by_group = self._get_state_groups_from_groups_txn(
txn, [state_group]
)
- curr_state = curr_state[state_group]
+ curr_state = curr_state_by_group[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return result * BATCH_SIZE_SCALE_FACTOR
- async def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..f1e3a27e63 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,43 +13,56 @@
# limitations under the License.
import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
"""Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
+ us use the iterable flag when caching
"""
- __slots__ = []
+ prev_group: Optional[int]
+ delta_ids: Optional[StateMap[str]]
- def __len__(self):
+ def __len__(self) -> int:
return len(self.delta_ids) if self.delta_ids else 0
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# We size the non-members cache to be smaller than the members cache as the
# vast majority of state in Matrix (today) is member events.
- self._state_group_cache = DictionaryCache(
+ self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000,
)
- self._state_group_members_cache = DictionaryCache(
+ self._state_group_members_cache: DictionaryCache[
+ int, StateKey, str
+ ] = DictionaryCache(
"*stateGroupMembersCache*",
500000,
)
- def get_max_state_group_txn(txn: Cursor):
+ def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- return txn.fetchone()[0]
+ return txn.fetchone()[0] # type: ignore
self._state_group_seq_gen = build_sequence_generator(
db_conn,
@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
- async def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- (prev_group, delta_ids), where both may be None.
+ _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
"""
- def _get_state_group_delta_txn(txn):
+ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
prev_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
- results = {}
+ results: Dict[int, StateMap[str]] = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return results
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ def _get_state_for_group_using_cache(
+ self,
+ cache: DictionaryCache[int, StateKey, str],
+ group: int,
+ state_filter: StateFilter,
+ ) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ cache: the state group cache to use
+ group: The state group to lookup
+ state_filter: The state filter used to fetch state from the database.
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
+ Returns:
+ 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
+ requests state from the cache, if False we need to query the DB for the
+ missing state.
"""
cache_entry = cache.get(group)
state_dict_ids = cache_entry.value
@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
def _get_state_for_groups_using_cache(
- self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
- ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+ self,
+ groups: Iterable[int],
+ cache: DictionaryCache[int, StateKey, str],
+ state_filter: StateFilter,
+ ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
def _insert_into_cache(
self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
+ group_to_state_dict: Dict[int, StateMap[str]],
+ state_filter: StateFilter,
+ cache_seq_num_members: int,
+ cache_seq_num_non_members: int,
+ ) -> None:
"""Inserts results from querying the database into the relevant cache.
Args:
- group_to_state_dict (dict): The new entries pulled from database.
+ group_to_state_dict: The new entries pulled from database.
Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
- cache_seq_num_members (int): Sequence number of member cache since
+ cache_seq_num_members: Sequence number of member cache since
last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
+ cache_seq_num_non_members: Sequence number of member cache since
last lookup in cache
"""
@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn):
+ def _store_state_group_txn(txn: LoggingTransaction) -> int:
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ assert delta_ids is not None
+
self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete
+ self, room_id: str, state_groups_to_delete: Collection[int]
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Args:
room_id: The room the state groups belong to (must all be in the
same room).
- state_groups_to_delete (Collection[int]): Set of all state groups
- to delete.
+ state_groups_to_delete: Set of all state groups to delete.
"""
await self.db_pool.runInteraction(
@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ def _purge_unreferenced_state_groups(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
logger.info(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# groups to non delta versions.
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
+ curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state_by_group[sg]
self.db_pool.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- async def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(
+ self, room_id: str, state_groups_to_delete: Collection[int]
+ ) -> None:
"""Deletes all record of a room from state tables
Args:
- room_id (str):
- state_groups_to_delete (list[int]): State groups to delete
+ room_id:
+ state_groups_to_delete: State groups to delete
"""
await self.db_pool.runInteraction(
@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ def _purge_room_state_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index c552dbf04c..10a46b5e82 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -73,7 +73,7 @@ class RelationPaginationToken:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
- raise SynapseError(400, "Invalid token")
+ raise SynapseError(400, "Invalid relation pagination token")
def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
@@ -103,7 +103,7 @@ class AggregationPaginationToken:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
- raise SynapseError(400, "Invalid token")
+ raise SynapseError(400, "Invalid aggregation pagination token")
def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e5400d681a..5e86befde4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,12 +25,15 @@ from typing import (
)
import attr
+from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
+ from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
+
from synapse.server import HomeServer
from synapse.storage.databases import Databases
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.
@@ -53,14 +56,19 @@ class StateFilter:
appear in `types`.
"""
- types = attr.ib(type=Dict[str, Optional[Set[str]]])
+ types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in self.types.items() if v is not None}
+ # this is needed to work around the fact that StateFilter is frozen
+ object.__setattr__(
+ self,
+ "types",
+ frozendict({k: v for k, v in self.types.items() if v is not None}),
+ )
@staticmethod
def all() -> "StateFilter":
@@ -69,7 +77,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=True)
+ return StateFilter(types=frozendict(), include_others=True)
@staticmethod
def none() -> "StateFilter":
@@ -78,7 +86,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=False)
+ return StateFilter(types=frozendict(), include_others=False)
@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +111,12 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore
- return StateFilter(types=type_dict)
+ return StateFilter(
+ types=frozendict(
+ (k, frozenset(v) if v is not None else None)
+ for k, v in type_dict.items()
+ )
+ )
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +129,10 @@ class StateFilter:
Returns:
The new state filter
"""
- return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+ return StateFilter(
+ types=frozendict({EventTypes.Member: frozenset(members)}),
+ include_others=True,
+ )
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +189,7 @@ class StateFilter:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
- types={EventTypes.Member: self.types[EventTypes.Member]},
+ types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
@@ -245,14 +261,15 @@ class StateFilter:
return len(self.concrete_types())
- def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
- """Returns the state filtered with by this StateFilter
+ def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+ """Returns the state filtered with by this StateFilter.
Args:
state: The state map to filter
Returns:
- The filtered state map
+ The filtered state map.
+ This is a copy, so it's safe to mutate.
"""
if self.is_full():
return dict(state_dict)
@@ -324,14 +341,16 @@ class StateFilter:
if state_keys is None:
member_filter = StateFilter.all()
else:
- member_filter = StateFilter({EventTypes.Member: state_keys})
+ member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+ types=frozendict(
+ {k: v for k, v in self.types.items() if k != EventTypes.Member}
+ ),
include_others=self.include_others,
)
@@ -358,7 +377,8 @@ class StateGroupStorage:
make up the delta between the old and new state groups.
"""
- return await self.stores.state.get_state_group_delta(state_group)
+ state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+ return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
diff --git a/synapse/types.py b/synapse/types.py
index d4759b2dfd..90168ce8fa 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -511,7 +511,7 @@ class RoomStreamToken:
)
except Exception:
pass
- raise SynapseError(400, "Invalid token %r" % (string,))
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
@@ -520,7 +520,7 @@ class RoomStreamToken:
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
- raise SynapseError(400, "Invalid token %r" % (string,))
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
"""Return a new token such that if an event is after both this token and
@@ -619,7 +619,7 @@ class StreamToken:
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
)
except Exception:
- raise SynapseError(400, "Invalid Token")
+ raise SynapseError(400, "Invalid stream token")
async def to_string(self, store: "DataStore") -> str:
return self._SEPARATOR.join(
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index ade088aae2..485ddb1893 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -130,7 +130,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
sequence: int,
key: KT,
value: Dict[DKT, DV],
- fetched_keys: Optional[Set[DKT]] = None,
+ fetched_keys: Optional[Iterable[DKT]] = None,
) -> None:
"""Updates the entry in the cache
@@ -155,7 +155,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(
- self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
+ self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index ac419f0db3..01b1b0d4a0 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
# limitations under the License.
import logging
import os
-from binascii import unhexlify
from typing import Optional, Tuple
from twisted.internet.protocol import Factory
@@ -28,6 +27,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+from tests.test_utils import SMALL_PNG
logger = logging.getLogger(__name__)
@@ -190,31 +190,25 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
- png_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
- request1.write(png_data)
+ request1.write(SMALL_PNG)
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
- self.assertEqual(channel1.result["body"], png_data)
+ self.assertEqual(channel1.result["body"], SMALL_PNG)
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
- request2.write(png_data)
+ request2.write(SMALL_PNG)
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
- self.assertEqual(channel2.result["body"], png_data)
+ self.assertEqual(channel2.result["body"], SMALL_PNG)
# We expect only three new thumbnails to have been persisted.
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index bfa638fb4b..febd40b656 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
import json
import os
import urllib.parse
-from binascii import unhexlify
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -28,6 +27,7 @@ from synapse.rest.client import groups, login, room
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
class VersionTestCase(unittest.HomeserverTestCase):
@@ -150,11 +150,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
- self.image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
def make_homeserver(self, reactor, clock):
@@ -266,7 +261,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=admin_user_tok
)
# Extract media ID from the response
@@ -314,10 +309,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract mxcs
@@ -381,10 +376,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract media IDs
@@ -421,10 +416,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract media IDs
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 972d60570c..2f02934e72 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -1,4 +1,5 @@
# Copyright 2020 Dirk Klimpel
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,6 @@
import json
import os
-from binascii import unhexlify
from parameterized import parameterized
@@ -25,6 +25,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@@ -110,15 +111,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
download_resource = self.media_repo.children[b"download"]
upload_resource = self.media_repo.children[b"upload"]
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -504,16 +500,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
Create a media and return media_id and server_and_media_id
"""
upload_resource = self.media_repo.children[b"upload"]
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -584,16 +574,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# Create media
upload_resource = media_repo.children[b"upload"]
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -711,16 +695,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
# Create media
upload_resource = media_repo.children[b"upload"]
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5cd82209c4..ece89a65ac 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -1,4 +1,5 @@
# Copyright 2020 Dirk Klimpel
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,6 @@
# limitations under the License.
import json
-from binascii import unhexlify
from typing import Any, Dict, List, Optional
import synapse.rest.admin
@@ -21,6 +21,7 @@ from synapse.api.errors import Codes
from synapse.rest.client import login
from tests import unittest
+from tests.test_utils import SMALL_PNG
class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
@@ -468,16 +469,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
upload_resource = self.media_repo.children[b"upload"]
for _ in range(number_media):
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
# Upload some media into the room
self.helper.upload_media(
- upload_resource, image_data, tok=user_token, expect_code=200
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
def _check_fields(self, content: List[Dict[str, Any]]):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ee204c404b..cc3f16c62a 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeSite, make_request
-from tests.test_utils import make_awaitable
+from tests.test_utils import SMALL_PNG, make_awaitable
from tests.unittest import override_config
@@ -2835,11 +2835,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
- image_data1 = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
+ image_data1 = SMALL_PNG
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
image_data2 = unhexlify(
b"47494638376101000100800100000000"
@@ -2943,14 +2939,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
media_ids = []
for _ in range(number_media):
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
- media_ids.append(self._create_media_and_access(user_token, image_data))
+ media_ids.append(self._create_media_and_access(user_token, SMALL_PNG))
return media_ids
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2f7eebfe69..9ea1c2bf25 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,6 +38,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
from tests.utils import default_config
@@ -134,11 +135,7 @@ class _TestImage:
# smoll png
(
_TestImage(
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- ),
+ SMALL_PNG,
b"image/png",
b".png",
unhexlify(
@@ -593,15 +590,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
def test_upload_innocent(self):
"""Attempt to upload some innocent data that should be allowed."""
-
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
self.helper.upload_media(
- self.upload_resource, image_data, tok=self.tok, expect_code=200
+ self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
def test_upload_ban(self):
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8695264595..32060f2abd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -14,6 +14,8 @@
import logging
+from frozendict import frozendict
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
- types={EventTypes.Member: {self.u_alice.to_string()}},
+ types=frozendict(
+ {EventTypes.Member: frozenset({self.u_alice.to_string()})}
+ ),
include_others=True,
),
)
@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}),
+ include_others=True,
),
)
)
@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=False
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=False,
),
)
@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=False
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=False,
),
)
@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=False
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=False,
),
)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index be6302d170..15ac2bfeba 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,5 +1,4 @@
-# Copyright 2019 New Vector Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +18,7 @@ Utilities for running the unit tests
import sys
import warnings
from asyncio import Future
+from binascii import unhexlify
from typing import Any, Awaitable, Callable, TypeVar
from unittest.mock import Mock
@@ -117,3 +117,13 @@ class FakeResponse:
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
+
+
+# A small image used in some tests.
+#
+# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
+SMALL_PNG = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+)
|