diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
index 808f825331..2bf32e376b 100644
--- a/.github/workflows/docs.yaml
+++ b/.github/workflows/docs.yaml
@@ -61,6 +61,5 @@ jobs:
uses: peaceiris/actions-gh-pages@068dc23d9710f1ba62e86896f84735d869951305 # v3.8.0
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- keep_files: true
publish_dir: ./book
destination_dir: ./${{ steps.vars.outputs.branch-version }}
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 8736699ad8..fa9c5e036a 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -192,6 +192,7 @@ jobs:
volumes:
- ${{ github.workspace }}:/src
env:
+ SYTEST_BRANCH: ${{ github.head_ref }}
POSTGRES: ${{ matrix.postgres && 1}}
MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1}}
WORKERS: ${{ matrix.workers && 1 }}
diff --git a/CHANGES.md b/CHANGES.md
index 4e9cefe69c..652f4b7955 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,7 +1,23 @@
-Synapse 1.43.0rc1 (2021-09-14)
+Synapse 1.43.0 (2021-09-21)
+===========================
+
+This release drops support for the deprecated, unstable API for [MSC2858 (Multiple SSO Identity Providers)](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), as well as the undocumented `experimental.msc2858_enabled` config option. Client authors should update their clients to use the stable API, available since Synapse 1.30.
+
+The documentation has been updated with configuration for routing `/spaces`, `/hierarchy` and `/summary` to workers. See [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.43/docs/upgrade.md#upgrading-to-v1430) for more details.
+
+No significant changes since 1.43.0rc2.
+
+Synapse 1.43.0rc2 (2021-09-17)
==============================
-This release drops support for the deprecated, unstable API for [MSC2858](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), as well as the undocumented `experimental.msc2858_enabled` config option. Client authors should update their clients to use the stable API, available since Synapse 1.30.
+Bugfixes
+--------
+
+- Added opentracing logging to help debug [\#9424](https://github.com/matrix-org/synapse/issues/9424). ([\#10828](https://github.com/matrix-org/synapse/issues/10828))
+
+
+Synapse 1.43.0rc1 (2021-09-14)
+==============================
Features
--------
diff --git a/changelog.d/10659.misc b/changelog.d/10659.misc
new file mode 100644
index 0000000000..d677a521c3
--- /dev/null
+++ b/changelog.d/10659.misc
@@ -0,0 +1 @@
+Fix GitHub Actions config so we can run sytest on Synapse from parallel branches.
\ No newline at end of file
diff --git a/changelog.d/10776.feature b/changelog.d/10776.feature
new file mode 100644
index 0000000000..aec0685a3d
--- /dev/null
+++ b/changelog.d/10776.feature
@@ -0,0 +1 @@
+Only allow the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send?chunk_id=xxx` endpoint to connect to an already existing insertion event.
diff --git a/changelog.d/10777.misc b/changelog.d/10777.misc
new file mode 100644
index 0000000000..aed78a16f5
--- /dev/null
+++ b/changelog.d/10777.misc
@@ -0,0 +1 @@
+Split out [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) meta events to their own fields in the `/batch_send` response.
diff --git a/changelog.d/10785.misc b/changelog.d/10785.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10785.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10796.misc b/changelog.d/10796.misc
new file mode 100644
index 0000000000..1873b2386a
--- /dev/null
+++ b/changelog.d/10796.misc
@@ -0,0 +1 @@
+Simplify the internal logic which maintains the user directory database tables.
\ No newline at end of file
diff --git a/changelog.d/10807.bugfix b/changelog.d/10807.bugfix
new file mode 100644
index 0000000000..be03f5c738
--- /dev/null
+++ b/changelog.d/10807.bugfix
@@ -0,0 +1 @@
+Allow sending a membership event to unban a user. Contributed by @aaronraimist.
\ No newline at end of file
diff --git a/changelog.d/10810.bugfix b/changelog.d/10810.bugfix
new file mode 100644
index 0000000000..43e91f1f51
--- /dev/null
+++ b/changelog.d/10810.bugfix
@@ -0,0 +1 @@
+Fix a case where logging contexts would go missing when federation requests time out.
diff --git a/changelog.d/10812.misc b/changelog.d/10812.misc
new file mode 100644
index 0000000000..586a0b3a96
--- /dev/null
+++ b/changelog.d/10812.misc
@@ -0,0 +1 @@
+Use direct references to config flags.
diff --git a/changelog.d/10814.feature b/changelog.d/10814.feature
new file mode 100644
index 0000000000..4fa95a6cc9
--- /dev/null
+++ b/changelog.d/10814.feature
@@ -0,0 +1 @@
+Improve oEmbed previews by processing the author name, photo, and video information.
diff --git a/changelog.d/10815.misc b/changelog.d/10815.misc
new file mode 100644
index 0000000000..fc2534dc14
--- /dev/null
+++ b/changelog.d/10815.misc
@@ -0,0 +1 @@
+Specify the type of token in generic "Invalid token" error messages.
\ No newline at end of file
diff --git a/changelog.d/10816.misc b/changelog.d/10816.misc
new file mode 100644
index 0000000000..2ca55b334a
--- /dev/null
+++ b/changelog.d/10816.misc
@@ -0,0 +1 @@
+Make `StateFilter` frozen so it is hashable.
diff --git a/changelog.d/10817.misc b/changelog.d/10817.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10817.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10823.misc b/changelog.d/10823.misc
new file mode 100644
index 0000000000..0532969900
--- /dev/null
+++ b/changelog.d/10823.misc
@@ -0,0 +1 @@
+Add type hints to the state database.
diff --git a/changelog.d/10828.bugfix b/changelog.d/10828.bugfix
deleted file mode 100644
index e00c10ec81..0000000000
--- a/changelog.d/10828.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Added opentrace logging to help debug #9424.
diff --git a/changelog.d/10829.misc b/changelog.d/10829.misc
new file mode 100644
index 0000000000..ac5fd6b047
--- /dev/null
+++ b/changelog.d/10829.misc
@@ -0,0 +1 @@
+Track cache eviction rates more finely in Prometheus' monitoring.
\ No newline at end of file
diff --git a/changelog.d/10831.misc b/changelog.d/10831.misc
new file mode 100644
index 0000000000..f09af2e00a
--- /dev/null
+++ b/changelog.d/10831.misc
@@ -0,0 +1 @@
+Add missing type hints to handlers.
diff --git a/changelog.d/10834.misc b/changelog.d/10834.misc
new file mode 100644
index 0000000000..037695e6e9
--- /dev/null
+++ b/changelog.d/10834.misc
@@ -0,0 +1 @@
+Factor out PNG image data to a constant to be used in several tests.
diff --git a/changelog.d/10835.misc b/changelog.d/10835.misc
new file mode 100644
index 0000000000..0c3d13477e
--- /dev/null
+++ b/changelog.d/10835.misc
@@ -0,0 +1 @@
+Add a test to ensure state events sent by modules get persisted correctly.
diff --git a/changelog.d/10838.misc b/changelog.d/10838.misc
new file mode 100644
index 0000000000..b1977d0a2e
--- /dev/null
+++ b/changelog.d/10838.misc
@@ -0,0 +1 @@
+Rename [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) fields and event types from `chunk` to `batch` to match the `/batch_send` endpoint.
diff --git a/changelog.d/10839.misc b/changelog.d/10839.misc
new file mode 100644
index 0000000000..d0e10f31d5
--- /dev/null
+++ b/changelog.d/10839.misc
@@ -0,0 +1 @@
+Rename the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` query parameter from `?prev_event` to the more descriptive `?prev_event_id`.
diff --git a/changelog.d/10843.bugfix b/changelog.d/10843.bugfix
new file mode 100644
index 0000000000..5027a1dbef
--- /dev/null
+++ b/changelog.d/10843.bugfix
@@ -0,0 +1 @@
+Fix a bug causing the `remove_stale_pushers` background job to repeatedly fail and log errors. This bug affected Synapse servers that had been upgraded from version 1.28 or older and were using SQLite.
diff --git a/changelog.d/10845.doc b/changelog.d/10845.doc
new file mode 100644
index 0000000000..a13c845ae6
--- /dev/null
+++ b/changelog.d/10845.doc
@@ -0,0 +1 @@
+Fix some crashes in the Module API example code by adding JSON encoding/decoding.
diff --git a/changelog.d/10856.misc b/changelog.d/10856.misc
new file mode 100644
index 0000000000..f09af2e00a
--- /dev/null
+++ b/changelog.d/10856.misc
@@ -0,0 +1 @@
+Add missing type hints to handlers.
diff --git a/changelog.d/10859.bugfix b/changelog.d/10859.bugfix
new file mode 100644
index 0000000000..c1bfe22d54
--- /dev/null
+++ b/changelog.d/10859.bugfix
@@ -0,0 +1 @@
+Fix a bug in Unicode support of the room search admin API. It is now possible to search for rooms with non-ASCII characters.
\ No newline at end of file
diff --git a/changelog.d/10867.misc b/changelog.d/10867.misc
new file mode 100644
index 0000000000..01e51fbc6e
--- /dev/null
+++ b/changelog.d/10867.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.http.site`.
diff --git a/changelog.d/10869.doc b/changelog.d/10869.doc
new file mode 100644
index 0000000000..c117386072
--- /dev/null
+++ b/changelog.d/10869.doc
@@ -0,0 +1 @@
+Properly remove deleted files from GitHub pages when generating the documentation.
diff --git a/changelog.d/10879.misc b/changelog.d/10879.misc
new file mode 100644
index 0000000000..acc04930fa
--- /dev/null
+++ b/changelog.d/10879.misc
@@ -0,0 +1 @@
+Include outlier status when we log V2 or V3 events.
diff --git a/debian/changelog b/debian/changelog
index d3a6f6a4e6..4b07d04128 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,15 @@
+matrix-synapse-py3 (1.43.0) stable; urgency=medium
+
+ * New synapse release 1.43.0.
+
+ -- Synapse Packaging team <packages@matrix.org> Tue, 21 Sep 2021 11:49:05 +0100
+
+matrix-synapse-py3 (1.43.0~rc2) stable; urgency=medium
+
+ * New synapse release 1.43.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org> Fri, 17 Sep 2021 10:43:21 +0100
+
matrix-synapse-py3 (1.43.0~rc1) stable; urgency=medium
* New synapse release 1.43.0~rc1.
diff --git a/docs/development/url_previews.md b/docs/development/url_previews.md
index bbe05e281c..aff3813609 100644
--- a/docs/development/url_previews.md
+++ b/docs/development/url_previews.md
@@ -25,16 +25,14 @@ When Synapse is asked to preview a URL it does the following:
3. Kicks off a background process to generate a preview:
1. Checks the database cache by URL and timestamp and returns the result if it
has not expired and was successful (a 2xx return code).
- 2. Checks if the URL matches an oEmbed pattern. If it does, fetch the oEmbed
- response. If this is an image, replace the URL to fetch and continue. If
- if it is HTML content, use the HTML as the document and continue.
- 3. If it doesn't match an oEmbed pattern, downloads the URL and stores it
- into a file via the media storage provider and saves the local media
- metadata.
- 5. If the media is an image:
+ 2. Checks if the URL matches an [oEmbed](https://oembed.com/) pattern. If it
+ does, update the URL to download.
+ 3. Downloads the URL and stores it into a file via the media storage provider
+ and saves the local media metadata.
+ 4. If the media is an image:
1. Generates thumbnails.
2. Generates an Open Graph response based on image properties.
- 6. If the media is HTML:
+ 5. If the media is HTML:
1. Decodes the HTML via the stored file.
2. Generates an Open Graph response from the HTML.
3. If an image exists in the Open Graph response:
@@ -42,6 +40,13 @@ When Synapse is asked to preview a URL it does the following:
provider and saves the local media metadata.
2. Generates thumbnails.
3. Updates the Open Graph response based on image properties.
+ 6. If the media is JSON and an oEmbed URL was found:
+ 1. Convert the oEmbed response to an Open Graph response.
+ 2. If a thumbnail or image is in the oEmbed response:
+ 1. Downloads the URL and stores it into a file via the media storage
+ provider and saves the local media metadata.
+ 2. Generates thumbnails.
+ 3. Updates the Open Graph response based on image properties.
7. Stores the result in the database cache.
4. Returns the result.
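For readers skimming the renumbered list above, here is the branching it describes as a minimal Python sketch. All names, return strings, and the `media_kind` parameter are hypothetical stand-ins, not Synapse code:

```python
from typing import Optional

# A control-flow sketch of the preview steps; "media_kind" is what the
# step-3 download produced ("image", "html", or "json").
def preview_flow(cached: Optional[dict], oembed_url: Optional[str],
                 media_kind: str) -> str:
    if cached is not None:
        return "cached result"                    # step 1: fresh 2xx cache hit
    if oembed_url is not None:
        pass                                      # step 2: update the URL to download
    # step 3: download and store via the media storage provider
    if media_kind == "image":
        return "OG from image + thumbnails"       # step 4
    if media_kind == "html":
        return "OG from HTML (maybe + image)"     # step 5
    if media_kind == "json" and oembed_url:
        return "OG from oEmbed response"          # step 6
    return "empty result, stored in cache"        # step 7

assert preview_flow(None, "https://oembed.example/api", "json") == "OG from oEmbed response"
```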
diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md
index c45eafcc4b..81574a015c 100644
--- a/docs/modules/spam_checker_callbacks.md
+++ b/docs/modules/spam_checker_callbacks.md
@@ -136,9 +136,9 @@ class IsUserEvilResource(Resource):
self.evil_users = config.get("evil_users") or []
def render_GET(self, request: Request):
- user = request.args.get(b"user")[0]
+ user = request.args.get(b"user")[0].decode()
request.setHeader(b"Content-Type", b"application/json")
- return json.dumps({"evil": user in self.evil_users})
+ return json.dumps({"evil": user in self.evil_users}).encode()
class ListSpamChecker:
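As an aside on the fix above: Twisted delivers `request.args` values as `bytes`, and `render_GET` must return `bytes`, which is what the added `.decode()`/`.encode()` calls account for. A standalone illustration (not module code):

```python
import json

# Without .decode(), the membership test compares bytes to str and is
# always False; without .encode(), Twisted rejects the str body.
evil_users = ["@evil:example.org"]

raw = b"@evil:example.org"       # what request.args.get(b"user")[0] yields
assert raw not in evil_users     # bytes never equal the str entries
user = raw.decode()
assert user in evil_users

body = json.dumps({"evil": user in evil_users}).encode()
assert isinstance(body, bytes)   # what Twisted expects back from render_GET
```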
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 95cca16552..166cec38d3 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2362,12 +2362,16 @@ user_directory:
#enabled: false
# Defines whether to search all users visible to your HS when searching
- # the user directory, rather than limiting to users visible in public
- # rooms. Defaults to false.
+ # the user directory. If false, search results will only contain users
+ # visible in public rooms and users sharing a room with the requester.
+ # Defaults to false.
#
- # If you set it true, you'll have to rebuild the user_directory search
- # indexes, see:
- # https://matrix-org.github.io/synapse/latest/user_directory.html
+ # NB. If you set this to true, and the last time the user_directory search
+ # indexes were (re)built was before Synapse 1.44, you'll have to
+ # rebuild the indexes in order to search through all known users.
+ # These indexes are built the first time Synapse starts; admins can
+ # manually trigger a rebuild following the instructions at
+ # https://matrix-org.github.io/synapse/latest/user_directory.html
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.
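The rule the reworded comment describes, sketched as a hypothetical predicate (not Synapse's implementation):

```python
# Semantics of search_all_users, reduced to a boolean rule.
def visible_in_search(search_all_users: bool,
                      in_public_room: bool,
                      shares_room_with_requester: bool) -> bool:
    if search_all_users:
        # All known users, provided the search indexes were (re)built
        # on Synapse 1.44 or later.
        return True
    return in_public_room or shares_room_with_requester

assert visible_in_search(False, False, True)
assert not visible_in_search(False, False, False)
```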
diff --git a/mypy.ini b/mypy.ini
index 09ffdda1b9..3cb6cecd7e 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -60,6 +60,7 @@ files =
synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
+ synapse/storage/databases/state,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/keys.py,
@@ -86,10 +87,14 @@ files =
tests/handlers/test_sync.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
+ tests/storage/test_state.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
-[mypy-synapse.rest.client.*]
+[mypy-synapse.handlers.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.rest.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]
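For context, broadening `disallow_untyped_defs` to `synapse.handlers.*` and `synapse.rest.*` means mypy now rejects any unannotated `def` in those packages, which is what the many `-> None` additions elsewhere in this diff satisfy. Illustrative only:

```python
# Under disallow_untyped_defs, mypy rejects the first definition
# ("Function is missing a type annotation") and accepts the second.

def notify(event):                       # rejected: no annotations
    ...

def notify_typed(event: dict) -> None:   # accepted: fully annotated
    ...
```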
diff --git a/synapse/__init__.py b/synapse/__init__.py
index d62ccd1dbc..5f5cff1dfd 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.43.0rc1"
+__version__ = "1.43.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 05699714ee..e6ca9232ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -70,8 +70,8 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
- self._track_appservice_user_ips = hs.config.track_appservice_user_ips
- self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+ self._macaroon_secret_key = hs.config.key.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
async def check_user_in_room(
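This hunk is one instance of the "direct references to config flags" cleanup (changelog #10812): flags move from attributes on the root config onto per-section config objects. A simplified stand-in, not Synapse's actual config classes:

```python
# Hypothetical classes illustrating the namespaced layout.
class KeyConfig:
    macaroon_secret_key = "..."

class RootConfig:
    key = KeyConfig()

config = RootConfig()
secret = config.key.macaroon_secret_key   # was: config.macaroon_secret_key
```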
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index e6bced93d5..a3b95f4de0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -30,13 +30,15 @@ class AuthBlocking:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- self._server_notices_mxid = hs.config.server_notices_mxid
- self._hs_disabled = hs.config.hs_disabled
- self._hs_disabled_message = hs.config.hs_disabled_message
- self._admin_contact = hs.config.admin_contact
- self._max_mau_value = hs.config.max_mau_value
- self._limit_usage_by_mau = hs.config.limit_usage_by_mau
- self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+ self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
+ self._hs_disabled = hs.config.server.hs_disabled
+ self._hs_disabled_message = hs.config.server.hs_disabled_message
+ self._admin_contact = hs.config.server.admin_contact
+ self._max_mau_value = hs.config.server.max_mau_value
+ self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+ self._mau_limits_reserved_threepids = (
+ hs.config.server.mau_limits_reserved_threepids
+ )
self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 236f0c7f99..39fd9954d5 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -121,7 +121,7 @@ class EventTypes:
SpaceParent = "m.space.parent"
MSC2716_INSERTION = "org.matrix.msc2716.insertion"
- MSC2716_CHUNK = "org.matrix.msc2716.chunk"
+ MSC2716_BATCH = "org.matrix.msc2716.batch"
MSC2716_MARKER = "org.matrix.msc2716.marker"
@@ -209,11 +209,11 @@ class EventContentFields:
# Used on normal messages to indicate they were historically imported after the fact
MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
- # For "insertion" events to indicate what the next chunk ID should be in
+ # For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
- MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
- # Used on "chunk" events to indicate which insertion event it connects to
- MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
+ MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
+ # Used on "batch" events to indicate which insertion event it connects to
+ MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
# For "marker" events
MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
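A hedged illustration of the renamed fields in use, with simplified event bodies: an insertion event advertises the next batch ID, and a batch event connects back to it.

```python
# Not real events; only the renamed type/field identifiers are from the diff.
insertion_event = {
    "type": "org.matrix.msc2716.insertion",
    "content": {"org.matrix.msc2716.next_batch_id": "batch-123"},
}
batch_event = {
    "type": "org.matrix.msc2716.batch",
    "content": {"org.matrix.msc2716.batch_id": "batch-123"},
}
assert (insertion_event["content"]["org.matrix.msc2716.next_batch_id"]
        == batch_event["content"]["org.matrix.msc2716.batch_id"])
```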
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 61d9c658a9..0a895bba48 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -244,24 +244,8 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
)
- MSC2716 = RoomVersion(
- "org.matrix.msc2716",
- RoomDisposition.UNSTABLE,
- EventFormatVersions.V3,
- StateResolutionVersions.V2,
- enforce_key_validity=True,
- special_case_aliases_auth=False,
- strict_canonicaljson=True,
- limit_notifications_power_levels=True,
- msc2176_redaction_rules=False,
- msc3083_join_rules=False,
- msc3375_redaction_rules=False,
- msc2403_knocking=True,
- msc2716_historical=True,
- msc2716_redactions=False,
- )
- MSC2716v2 = RoomVersion(
- "org.matrix.msc2716v2",
+ MSC2716v3 = RoomVersion(
+ "org.matrix.msc2716v3",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
@@ -289,9 +273,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V6,
RoomVersions.MSC2176,
RoomVersions.V7,
- RoomVersions.MSC2716,
RoomVersions.V8,
RoomVersions.V9,
+ RoomVersions.MSC2716v3,
)
}
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 0f5b2b3977..83994df798 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List
+from typing import Any, List, Tuple, Type
from synapse.util.module_loader import load_module
@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
- self.password_providers: List[Any] = []
+ self.password_providers: List[Tuple[Type, Any]] = []
providers = []
# We want to be backwards compatible with the old `ldap_config`
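What the sharpened annotation captures, with a hypothetical provider class: each entry pairs the loaded provider class with its parsed config.

```python
from typing import Any, List, Tuple, Type

class LdapAuthProvider:   # hypothetical example provider
    ...

password_providers: List[Tuple[Type, Any]] = [
    (LdapAuthProvider, {"uri": "ldap://..."}),
]
```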
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index b10df8a232..2552f688d0 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -45,12 +45,16 @@ class UserDirectoryConfig(Config):
#enabled: false
# Defines whether to search all users visible to your HS when searching
- # the user directory, rather than limiting to users visible in public
- # rooms. Defaults to false.
+ # the user directory. If false, search results will only contain users
+ # visible in public rooms and users sharing a room with the requester.
+ # Defaults to false.
#
- # If you set it true, you'll have to rebuild the user_directory search
- # indexes, see:
- # https://matrix-org.github.io/synapse/latest/user_directory.html
+ # NB. If you set this to true, and the last time the user_directory search
+ # indexes were (re)built was before Synapse 1.44, you'll have to
+ # rebuild the indexes in order to search through all known users.
+ # These indexes are built the first time Synapse starts; admins can
+ # manually trigger a rebuild following the instructions at
+ # https://matrix-org.github.io/synapse/latest/user_directory.html
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index c644b4dfc5..d310976fe3 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -102,7 +102,7 @@ class FederationPolicyForHTTPS:
self._config = config
# Check if we're using a custom list of a CA certificates
- trust_root = config.federation_ca_trust_root
+ trust_root = config.tls.federation_ca_trust_root
if trust_root is None:
# Use CA root certs provided by OpenSSL
trust_root = platformTrust()
@@ -113,7 +113,7 @@ class FederationPolicyForHTTPS:
# moving to TLS 1.2 by default, we want to respect the config option if
# it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
# let us do).
- minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
+ minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version]
_verify_ssl = CertificateOptions(
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
@@ -125,10 +125,10 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- self._should_verify = self._config.federation_verify_certificates
+ self._should_verify = self._config.tls.federation_verify_certificates
self._federation_certificate_verification_whitelist = (
- self._config.federation_certificate_verification_whitelist
+ self._config.tls.federation_certificate_verification_whitelist
)
def get_options(self, host: bytes):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 9e9b1c1c86..e1e13a2412 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -572,7 +572,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- self.key_servers = self.config.key_servers
+ self.key_servers = self.config.key.key_servers
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index cb133f3f84..fc50a0e71a 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -213,7 +213,7 @@ def check(
if (
event.type == EventTypes.MSC2716_INSERTION
- or event.type == EventTypes.MSC2716_CHUNK
+ or event.type == EventTypes.MSC2716_BATCH
or event.type == EventTypes.MSC2716_MARKER
):
check_historical(room_version_obj, event, auth_events)
@@ -552,14 +552,14 @@ def check_historical(
auth_events: StateMap[EventBase],
) -> None:
"""Check whether the event sender is allowed to send historical related
- events like "insertion", "chunk", and "marker".
+ events like "insertion", "batch", and "marker".
Returns:
None
Raises:
AuthError if the event sender is not allowed to send historical related events
- ("insertion", "chunk", and "marker").
+ ("insertion", "batch", and "marker").
"""
# Ignore the auth checks in room versions that do not support historical
# events
@@ -573,7 +573,7 @@ def check_historical(
if user_level < historical_level:
raise AuthError(
403,
- 'You don\'t have permission to send send historical related events ("insertion", "chunk", and "marker")',
+ 'You don\'t have permission to send historical related events ("insertion", "batch", and "marker")',
)
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index a730c1719a..49190459c8 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -344,6 +344,18 @@ class EventBase(metaclass=abc.ABCMeta):
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
+ self.__class__.__name__,
+ self.event_id,
+ self.get("type", None),
+ self.get("state_key", None),
+ self.internal_metadata.is_outlier(),
+ )
+
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
@@ -392,17 +404,6 @@ class FrozenEvent(EventBase):
def event_id(self) -> str:
return self._event_id
- def __str__(self):
- return self.__repr__()
-
- def __repr__(self):
- return "<FrozenEvent event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
- self.get("event_id", None),
- self.get("type", None),
- self.get("state_key", None),
- self.internal_metadata.is_outlier(),
- )
-
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
@@ -478,17 +479,6 @@ class FrozenEventV2(EventBase):
"""
return self.auth_events
- def __str__(self):
- return self.__repr__()
-
- def __repr__(self):
- return "<%s event_id=%r, type=%r, state_key=%r>" % (
- self.__class__.__name__,
- self.event_id,
- self.get("type", None),
- self.get("state_key", None),
- )
-
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index fb22337e27..f86113a448 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -141,9 +141,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
add_fields("redacts")
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
- add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID)
- elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK:
- add_fields(EventContentFields.MSC2716_CHUNK_ID)
+ add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID)
+ elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
+ add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
add_fields(EventContentFields.MSC2716_MARKER_INSERTION)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 214ee948fa..638959cbec 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1237,7 +1237,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
- if not self.config.use_presence and edu_type == EduTypes.Presence:
+ if not self.config.server.use_presence and edu_type == EduTypes.Presence:
return
# Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4671ac0242..720d7bd74d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -594,7 +594,7 @@ class FederationSender(AbstractFederationSender):
destinations (list[str])
"""
- if not states or not self.hs.config.use_presence:
+ if not states or not self.hs.config.server.use_presence:
# No-op if presence is disabled.
return
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index c23ccd6dd9..0ccef884e7 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.ratelimiting import Ratelimiter
+from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -63,16 +64,21 @@ class BaseHandler:
self.event_builder_factory = hs.get_event_builder_factory()
- async def ratelimit(self, requester, update=True, is_admin_redaction=False):
+ async def ratelimit(
+ self,
+ requester: Requester,
+ update: bool = True,
+ is_admin_redaction: bool = False,
+ ) -> None:
"""Ratelimits requests.
Args:
- requester (Requester)
- update (bool): Whether to record that a request is being processed.
+ requester
+ update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
- is_admin_redaction (bool): Whether this is a room admin/moderator
+ is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index affb54e0ee..96273e2f81 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
+from synapse.streams import EventSource
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@@ -163,7 +164,7 @@ class AccountDataHandler:
return response["max_stream_id"]
-class AccountDataEventSource:
+class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -171,7 +172,13 @@ class AccountDataEventSource:
return self.store.get_max_account_data_stream_id()
async def get_new_events(
- self, user: UserID, from_key: int, **kwargs
+ self,
+ user: UserID,
+ from_key: int,
+ limit: Optional[int],
+ room_ids: Collection[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
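A hedged sketch of the `EventSource[K, R]` shape the handler now subclasses explicitly; the real interface lives in `synapse.streams`, and this stand-in only mirrors the signature visible in the hunk:

```python
from typing import Any, Collection, Generic, List, Optional, Tuple, TypeVar

K = TypeVar("K")  # stream key type (int for account data)
R = TypeVar("R")  # event type (JsonDict for account data)

class EventSource(Generic[K, R]):
    async def get_new_events(
        self,
        user: Any,                    # a UserID in Synapse
        from_key: K,
        limit: Optional[int],
        room_ids: Collection[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[R], K]:
        raise NotImplementedError()
```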
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index a9c2222f46..4724565ba5 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -99,7 +99,7 @@ class AccountValidityHandler:
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
- ):
+ ) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
@@ -165,7 +165,7 @@ class AccountValidityHandler:
return False
- async def on_user_registration(self, user_id: str):
+ async def on_user_registration(self, user_id: str) -> None:
"""Tell third-party modules about a user's registration.
Args:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index a7b5a4e9c9..b7213b67a5 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from prometheus_client import Counter
@@ -58,7 +58,7 @@ class ApplicationServicesHandler:
self.current_max = 0
self.is_processing = False
- def notify_interested_services(self, max_token: RoomStreamToken):
+ def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@@ -82,7 +82,7 @@ class ApplicationServicesHandler:
self._notify_interested_services(max_token)
@wrap_as_background_process("notify_interested_services")
- async def _notify_interested_services(self, max_token: RoomStreamToken):
+ async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
@@ -100,7 +100,7 @@ class ApplicationServicesHandler:
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- async def handle_event(event):
+ async def handle_event(event: EventBase) -> None:
# Gather interested services
services = await self._get_services_for_event(event)
if len(services) == 0:
@@ -116,9 +116,9 @@ class ApplicationServicesHandler:
if not self.started_scheduler:
- async def start_scheduler():
+ async def start_scheduler() -> None:
try:
- return await self.scheduler.start()
+ await self.scheduler.start()
except Exception:
logger.error("Application Services Failure")
@@ -137,7 +137,7 @@ class ApplicationServicesHandler:
"appservice_sender"
).observe((now - ts) / 1000)
- async def handle_room_events(events):
+ async def handle_room_events(events: Iterable[EventBase]) -> None:
for event in events:
await handle_event(event)
@@ -184,7 +184,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
- ):
+ ) -> None:
"""This is called by the notifier in the background
when an ephemeral event is handled by the homeserver.
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]],
- ):
+ ) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
@@ -254,7 +254,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
- typing_source = self.event_sources.sources["typing"]
+ typing_source = self.event_sources.sources.typing
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
@@ -269,7 +269,7 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
- receipts_source = self.event_sources.sources["receipt"]
+ receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
@@ -279,7 +279,7 @@ class ApplicationServicesHandler:
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events: List[JsonDict] = []
- presence_source = self.event_sources.sources["presence"]
+ presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
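The `sources["typing"]` to `sources.typing` changes in this file imply the event-sources container moved from dict lookup to attribute access. A minimal stand-in (Synapse's actual container may differ):

```python
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class Sources:
    typing: object
    receipt: object
    presence: object

sources = Sources(typing="typing-source",
                  receipt="receipt-source",
                  presence="presence-source")
typing_source = sources.typing    # was: sources["typing"]
```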
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fbbf6fd834..bcd4249e09 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -29,6 +29,7 @@ from typing import (
Mapping,
Optional,
Tuple,
+ Type,
Union,
cast,
)
@@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):
return ui_auth_types
- def get_enabled_auth_types(self):
+ def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types
Returns the UI-Auth types which are supported by the homeserver's current
@@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- async def _expire_old_sessions(self):
+ async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
@@ -1347,12 +1348,12 @@ class AuthHandler(BaseHandler):
try:
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
- raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+ raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(res.user_id)
return res
- async def delete_access_token(self, access_token: str):
+ async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
Args:
@@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
- ):
+ ) -> None:
"""Invalidate access tokens belonging to a user
Args:
@@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
- ):
+ ) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
Hashed password.
"""
- def _do_hash():
+ def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
- def _do_validate_hash(checked_hash: bytes):
+ def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
- ):
+ ) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Args:
@@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
- ):
+ ) -> None:
"""
The synchronous portion of complete_sso_login.
@@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
del self._extra_attributes[user_id]
@staticmethod
- def add_query_param_to_url(url: str, param_name: str, param: Any):
+ def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
@@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
return urllib.parse.urlunparse(url_parts)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
- hs = attr.ib()
+ hs: "HomeServer"
def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
@@ -1816,7 +1817,9 @@ class PasswordProvider:
"""
@classmethod
- def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ def load(
+ cls, module: Type, config: JsonDict, module_api: ModuleApi
+ ) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
@@ -1824,7 +1827,7 @@ class PasswordProvider:
raise
return cls(pp, module_api)
- def __init__(self, pp, module_api: ModuleApi):
+ def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
@@ -1838,7 +1841,7 @@ class PasswordProvider:
if g:
self._supported_login_types.update(g())
- def __str__(self):
+ def __str__(self) -> str:
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@@ -1876,19 +1879,19 @@ class PasswordProvider:
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
- g = getattr(self._pp, "check_password", None)
- if g:
+ check_password = getattr(self._pp, "check_password", None)
+ if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
- is_valid = await self._pp.check_password(
+ is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
- g = getattr(self._pp, "check_auth", None)
- if not g:
+ check_auth = getattr(self._pp, "check_auth", None)
+ if not check_auth:
return None
- result = await g(username, login_type, login_dict)
+ result = await check_auth(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
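The `MacaroonGenerator` hunk above uses the same attrs conversion applied to `CasResponse`, `OidcSessionData`, and `JinjaOidcMappingConfig` later in this diff: `auto_attribs=True` turns plain annotations into fields, replacing explicit `attr.ib(type=...)` calls. Side by side:

```python
import attr

@attr.s(slots=True, frozen=True)
class OldStyle:
    username = attr.ib(type=str)

@attr.s(slots=True, frozen=True, auto_attribs=True)
class NewStyle:
    username: str

# Both styles generate the same __init__ and field behaviour.
assert OldStyle("alice").username == NewStyle("alice").username
```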
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 47ddabbe46..b0b188dc78 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -34,20 +34,20 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket."""
- def __init__(self, error, error_description=None):
+ def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
- def __str__(self):
+ def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
- username = attr.ib(type=str)
- attributes = attr.ib(type=Dict[str, List[Optional[str]]])
+ username: str
+ attributes: Dict[str, List[Optional[str]]]
class CasHandler:
@@ -133,11 +133,9 @@ class CasHandler:
body = pde.response
except HttpResponseException as e:
description = (
- (
- 'Authorization server responded with a "{status}" error '
- "while exchanging the authorization code."
- ).format(status=e.code),
- )
+ 'Authorization server responded with a "{status}" error '
+ "while exchanging the authorization code."
+ ).format(status=e.code)
raise CasError("server_error", description) from e
return self._parse_cas_response(body)
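The `description` hunk above fixes a real bug: a trailing comma after the parenthesised string literal made `description` a 1-tuple instead of a `str`. In isolation:

```python
# The fixed form: a plain str.
fixed = (
    'Authorization server responded with a "{status}" error '
    "while exchanging the authorization code."
).format(status=500)

# The buggy form: the trailing comma wraps the str in a tuple.
buggy = (
    (
        'Authorization server responded with a "{status}" error '
        "while exchanging the authorization code."
    ).format(status=500),
)   # <- trailing comma

assert isinstance(fixed, str)
assert isinstance(buggy, tuple)
```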
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index dcd320c555..a03ff9842b 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -257,11 +257,8 @@ class DeactivateAccountHandler(BaseHandler):
"""
# Add the user to the directory, if necessary.
user = UserID.from_string(user_id)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(user.localpart)
- await self.user_directory_handler.handle_local_profile_change(
- user_id, profile
- )
+ profile = await self.store.get_profileinfo(user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(user_id, profile)
# Ensure the user is not marked as erased.
await self.store.mark_user_not_erased(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 46ee834407..35334725d7 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -267,7 +267,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
- def _check_device_name_length(self, name: Optional[str]):
+ def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 08a137561f..d0fb2fc7dc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -202,7 +202,7 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
@trace
- async def do_remote_query(destination):
+ async def do_remote_query(destination: str) -> None:
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@@ -447,7 +447,7 @@ class E2eKeysHandler:
}
@trace
- async def claim_client_keys(destination):
+ async def claim_client_keys(destination: str) -> None:
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 4288ffff09..cb81fa0986 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
+from synapse.events.snapshot import EventContext
from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure
@@ -45,7 +46,11 @@ class EventAuthHandler:
self._server_name = hs.hostname
async def check_from_context(
- self, room_version: str, event, context, do_sig_check=True
+ self,
+ room_version: str,
+ event: EventBase,
+ context: EventContext,
+ do_sig_check: bool = True,
) -> None:
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6754c64c31..8e2cf3387a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1221,136 +1221,6 @@ class FederationHandler(BaseHandler):
return missing_events
- async def construct_auth_difference(
- self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
- ) -> Dict:
- """Given a local and remote auth chain, find the differences. This
- assumes that we have already processed all events in remote_auth
-
- Params:
- local_auth
- remote_auth
-
- Returns:
- dict
- """
-
- logger.debug("construct_auth_difference Start!")
-
- # TODO: Make sure we are OK with local_auth or remote_auth having more
- # auth events in them than strictly necessary.
-
- def sort_fun(ev):
- return ev.depth, ev.event_id
-
- logger.debug("construct_auth_difference after sort_fun!")
-
- # We find the differences by starting at the "bottom" of each list
- # and iterating up on both lists. The lists are ordered by depth and
- # then event_id, we iterate up both lists until we find the event ids
- # don't match. Then we look at depth/event_id to see which side is
- # missing that event, and iterate only up that list. Repeat.
-
- remote_list = list(remote_auth)
- remote_list.sort(key=sort_fun)
-
- local_list = list(local_auth)
- local_list.sort(key=sort_fun)
-
- local_iter = iter(local_list)
- remote_iter = iter(remote_list)
-
- logger.debug("construct_auth_difference before get_next!")
-
- def get_next(it, opt=None):
- try:
- return next(it)
- except Exception:
- return opt
-
- current_local = get_next(local_iter)
- current_remote = get_next(remote_iter)
-
- logger.debug("construct_auth_difference before while")
-
- missing_remotes = []
- missing_locals = []
- while current_local or current_remote:
- if current_remote is None:
- missing_locals.append(current_local)
- current_local = get_next(local_iter)
- continue
-
- if current_local is None:
- missing_remotes.append(current_remote)
- current_remote = get_next(remote_iter)
- continue
-
- if current_local.event_id == current_remote.event_id:
- current_local = get_next(local_iter)
- current_remote = get_next(remote_iter)
- continue
-
- if current_local.depth < current_remote.depth:
- missing_locals.append(current_local)
- current_local = get_next(local_iter)
- continue
-
- if current_local.depth > current_remote.depth:
- missing_remotes.append(current_remote)
- current_remote = get_next(remote_iter)
- continue
-
- # They have the same depth, so we fall back to the event_id order
- if current_local.event_id < current_remote.event_id:
- missing_locals.append(current_local)
- current_local = get_next(local_iter)
-
- if current_local.event_id > current_remote.event_id:
- missing_remotes.append(current_remote)
- current_remote = get_next(remote_iter)
- continue
-
- logger.debug("construct_auth_difference after while")
-
- # missing locals should be sent to the server
- # We should find why we are missing remotes, as they will have been
- # rejected.
-
- # Remove events from missing_remotes if they are referencing a missing
- # remote. We only care about the "root" rejected ones.
- missing_remote_ids = [e.event_id for e in missing_remotes]
- base_remote_rejected = list(missing_remotes)
- for e in missing_remotes:
- for e_id in e.auth_event_ids():
- if e_id in missing_remote_ids:
- try:
- base_remote_rejected.remove(e)
- except ValueError:
- pass
-
- reason_map = {}
-
- for e in base_remote_rejected:
- reason = await self.store.get_rejection_reason(e.event_id)
- if reason is None:
- # TODO: e is not in the current state, so we should
- # construct some proof of that.
- continue
-
- reason_map[e.event_id] = reason
-
- logger.debug("construct_auth_difference returning")
-
- return {
- "auth_chain": local_auth,
- "rejects": {
- e.event_id: {"reason": reason_map[e.event_id], "proof": None}
- for e in base_remote_rejected
- },
- "missing": [e.event_id for e in missing_locals],
- }
-
@log_function
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 946343fa25..3b95beeb08 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1016,7 +1016,7 @@ class FederationEventHandler:
except Exception:
logger.exception("Failed to resync device for %s", sender)
- async def _handle_marker_event(self, origin: str, marker_event: EventBase):
+ async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
@@ -1109,7 +1109,7 @@ class FederationEventHandler:
event_map: Dict[str, EventBase] = {}
- async def get_event(event_id: str):
+ async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
try:
event = await self._federation_client.get_pdu(
@@ -1218,7 +1218,7 @@ class FederationEventHandler:
if not event_infos:
return
- async def prep(ev_info: _NewEventInfo):
+ async def prep(ev_info: _NewEventInfo) -> EventContext:
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._state_handler.compute_event_context(event)
@@ -1692,7 +1692,7 @@ class FederationEventHandler:
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ):
+ ) -> None:
"""Run the push actions for a received event, and persist it.
Args:
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 1a6c5c64a2..9e270d461b 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, JsonDict, get_domain_from_id
@@ -25,12 +25,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _create_rerouter(func_name):
+def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
"""Returns an async function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
- async def f(self, group_id, *args, **kwargs):
+ async def f(
+ self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
+ ) -> JsonDict:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 4e8f7f1d85..9ad39a65d8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.internet import defer
@@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token()
- presence_stream = self.hs.get_event_sources().sources["presence"]
+ presence_stream = self.hs.get_event_sources().sources.presence
presence, _ = await presence_stream.get_new_events(
user, from_key=None, include_offline=False
)
@@ -150,7 +150,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- async def handle_room(event: RoomsForUser):
+ async def handle_room(event: RoomsForUser) -> None:
d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
@@ -411,9 +411,9 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
- async def get_presence():
+ async def get_presence() -> List[JsonDict]:
# If presence is disabled, return an empty list
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return []
states = await presence_handler.get_states(
@@ -428,7 +428,7 @@ class InitialSyncHandler(BaseHandler):
for s in states
]
- async def get_receipts():
+ async def get_receipts() -> List[JsonDict]:
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 60673cd4b8..bf2763b0f3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.handlers.directory import DirectoryHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -298,7 +299,7 @@ class MessageHandler:
for user_id, profile in users_with_profile.items()
}
- def maybe_schedule_expiry(self, event: EventBase):
+ def maybe_schedule_expiry(self, event: EventBase) -> None:
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
@@ -318,7 +319,7 @@ class MessageHandler:
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
- async def _schedule_next_expiry(self):
+ async def _schedule_next_expiry(self) -> None:
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
@@ -331,7 +332,7 @@ class MessageHandler:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
- def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
+ def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
@@ -367,7 +368,7 @@ class MessageHandler:
event_id,
)
- async def _expire_event(self, event_id: str):
+ async def _expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
@@ -1229,7 +1230,10 @@ class EventCreationHandler:
self._external_cache_joined_hosts_updates[state_entry.state_group] = None
async def _validate_canonical_alias(
- self, directory_handler, room_alias_str: str, expected_room_id: str
+ self,
+ directory_handler: DirectoryHandler,
+ room_alias_str: str,
+ expected_room_id: str,
) -> None:
"""
Ensure that the given room alias points to the expected room ID.
@@ -1421,7 +1425,7 @@ class EventCreationHandler:
# structural protocol level).
is_msc2716_event = (
original_event.type == EventTypes.MSC2716_INSERTION
- or original_event.type == EventTypes.MSC2716_CHUNK
+ or original_event.type == EventTypes.MSC2716_BATCH
or original_event.type == EventTypes.MSC2716_MARKER
)
if not room_version_obj.msc2716_historical and is_msc2716_event:
@@ -1477,7 +1481,7 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- def _notify():
+ def _notify() -> None:
try:
self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
@@ -1523,7 +1527,7 @@ class EventCreationHandler:
except Exception:
logger.exception("Error bumping presence active time")
- async def _send_dummy_events_to_fill_extremities(self):
+ async def _send_dummy_events_to_fill_extremities(self) -> None:
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
@@ -1600,7 +1604,7 @@ class EventCreationHandler:
)
return False
- def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
+ def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None:
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index dfc251b2a5..aed5a40a78 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode, urlparse
import attr
@@ -249,11 +249,11 @@ class OidcHandler:
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint"""
- def __init__(self, error, error_description=None):
+ def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
- def __str__(self):
+ def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
@@ -1057,13 +1057,13 @@ class JwtClientSecret:
self._cached_secret = b""
self._cached_secret_replacement_time = 0
- def __str__(self):
+ def __str__(self) -> str:
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")
- def __bytes__(self):
+ def __bytes__(self) -> bytes:
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()
@@ -1197,21 +1197,21 @@ class OidcSessionTokenGenerator:
)
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# the Identity Provider being used
- idp_id = attr.ib(type=str)
+ idp_id: str
# The `nonce` parameter passed to the OIDC provider.
- nonce = attr.ib(type=str)
+ nonce: str
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
- client_redirect_url = attr.ib(type=str)
+ client_redirect_url: str
# The session ID of the ongoing UI Auth ("" if this is a login)
- ui_auth_session_id = attr.ib(type=str)
+ ui_auth_session_id: str
class UserAttributeDict(TypedDict):
@@ -1290,20 +1290,20 @@ class OidcMappingProvider(Generic[C]):
# Used to clear out "None" values in templates
-def jinja_finalize(thing):
+def jinja_finalize(thing: Any) -> Any:
return thing if thing is not None else ""
env = Environment(finalize=jinja_finalize)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
- subject_claim = attr.ib(type=str)
- localpart_template = attr.ib(type=Optional[Template])
- display_name_template = attr.ib(type=Optional[Template])
- email_template = attr.ib(type=Optional[Template])
- extra_attributes = attr.ib(type=Dict[str, Template])
+ subject_claim: str
+ localpart_template: Optional[Template]
+ display_name_template: Optional[Template]
+ email_template: Optional[Template]
+ extra_attributes: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
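The attrs conversions above (and the matching ones in saml.py further down) swap `attr.ib(type=...)` declarations for plain annotations under `auto_attribs=True`. This should be behaviour-preserving; a minimal sketch of the equivalence, using a hypothetical `Example` class rather than one from this diff:

    import attr
    from typing import Optional

    # Old style: the type is passed to attr.ib() explicitly.
    @attr.s(frozen=True, slots=True)
    class ExampleOld:
        name = attr.ib(type=str)
        nickname = attr.ib(type=Optional[str], default=None)

    # New style: ordinary annotations, collected via auto_attribs=True.
    @attr.s(frozen=True, slots=True, auto_attribs=True)
    class ExampleNew:
        name: str
        nickname: Optional[str] = None

    # Both styles generate the same __init__/__eq__/__repr__.
    assert ExampleOld("a") == ExampleOld("a")
    assert ExampleNew("a").nickname is None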
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 7dc0ee4bef..08b93b3ec1 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,6 +15,8 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
+import attr
+
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
@@ -24,7 +26,7 @@ from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import Requester
+from synapse.types import JsonDict, Requester
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -36,15 +38,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
-
- Attributes:
- status (int): Tracks whether this request has completed. One of
- STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
@@ -57,10 +56,10 @@ class PurgeStatus:
STATUS_FAILED: "failed",
}
- def __init__(self):
- self.status = PurgeStatus.STATUS_ACTIVE
+ # Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}.
+ status: int = STATUS_ACTIVE
- def asdict(self):
+ def asdict(self) -> JsonDict:
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
@@ -107,7 +106,7 @@ class PaginationHandler:
async def purge_history_for_rooms_in_range(
self, min_ms: Optional[int], max_ms: Optional[int]
- ):
+ ) -> None:
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@@ -291,7 +290,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
- def clear_purge():
+ def clear_purge() -> None:
del self._purges_by_id[purge_id]
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 39b39cd3e2..983c837c66 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -26,18 +26,22 @@ import contextlib
import logging
from bisect import bisect
from contextlib import contextmanager
+from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Collection,
Dict,
FrozenSet,
+ Generator,
Iterable,
List,
Optional,
Set,
Tuple,
+ Type,
Union,
)
@@ -61,6 +65,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
+from synapse.streams import EventSource
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -240,7 +245,7 @@ class BasePresenceHandler(abc.ABC):
"""
@abc.abstractmethod
- async def bump_presence_active_time(self, user: UserID):
+ async def bump_presence_active_time(self, user: UserID) -> None:
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@@ -274,7 +279,7 @@ class BasePresenceHandler(abc.ABC):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
- ):
+ ) -> None:
"""Process streams received over replication."""
await self._federation_queue.process_replication_rows(
stream_name, instance_name, token, rows
@@ -286,7 +291,7 @@ class BasePresenceHandler(abc.ABC):
async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState]
- ):
+ ) -> None:
"""If this instance is a federation sender, send the states to all
destinations that are interested. Filters out any states for remote
users.
@@ -309,7 +314,7 @@ class BasePresenceHandler(abc.ABC):
for destination, host_states in hosts_to_states.items():
self._federation.send_presence_to_destinations(host_states, [destination])
- async def send_full_presence_to_users(self, user_ids: Collection[str]):
+ async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
"""
Adds to the list of users who should receive a full snapshot of presence
upon their next sync. Note that this only works for local users.
@@ -363,7 +368,12 @@ class BasePresenceHandler(abc.ABC):
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
pass
@@ -374,7 +384,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
- self._presence_enabled = hs.config.use_presence
+ self._presence_enabled = hs.config.server.use_presence
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
@@ -468,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
if self._user_to_num_current_syncs[user_id] == 1:
self.mark_as_coming_online(user_id)
- def _end():
+ def _end() -> None:
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
@@ -480,7 +490,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.mark_as_going_offline(user_id)
@contextlib.contextmanager
- def _user_syncing():
+ def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@@ -503,7 +513,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
- ):
+ ) -> None:
await super().process_replication_rows(stream_name, instance_name, token, rows)
if stream_name != PresenceStream.NAME:
@@ -584,7 +594,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id = target_user.to_string()
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@@ -601,7 +611,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@@ -618,7 +628,7 @@ class PresenceHandler(BasePresenceHandler):
self.server_name = hs.hostname
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
- self._presence_enabled = hs.config.use_presence
+ self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@@ -689,7 +699,7 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
- def run_timeout_handler():
+ def run_timeout_handler() -> Awaitable[None]:
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
@@ -698,7 +708,7 @@ class PresenceHandler(BasePresenceHandler):
30, self.clock.looping_call, run_timeout_handler, 5000
)
- def run_persister():
+ def run_persister() -> Awaitable[None]:
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
@@ -916,7 +926,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
return
user_id = user.to_string()
@@ -942,14 +952,14 @@ class PresenceHandler(BasePresenceHandler):
when users disconnect/reconnect.
Args:
- user_id (str)
- affect_presence (bool): If false this function will be a no-op.
+ user_id
+ affect_presence: If false this function will be a no-op.
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
# Override if it should affect the user's presence, if presence is
# disabled.
- if not self.hs.config.use_presence:
+ if not self.hs.config.server.use_presence:
affect_presence = False
if affect_presence:
@@ -978,7 +988,7 @@ class PresenceHandler(BasePresenceHandler):
]
)
- async def _end():
+ async def _end() -> None:
try:
self.user_to_num_current_syncs[user_id] -= 1
@@ -994,7 +1004,7 @@ class PresenceHandler(BasePresenceHandler):
logger.exception("Error updating presence after sync")
@contextmanager
- def _user_syncing():
+ def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@@ -1264,7 +1274,7 @@ class PresenceHandler(BasePresenceHandler):
if self._event_processing:
return
- async def _process_presence():
+ async def _process_presence() -> None:
assert not self._event_processing
self._event_processing = True
@@ -1491,7 +1501,7 @@ def format_user_presence_state(
return content
-class PresenceEventSource:
+class PresenceEventSource(EventSource[int, UserPresenceState]):
def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
@@ -1510,10 +1520,11 @@ class PresenceEventSource:
self,
user: UserID,
from_key: Optional[int],
+ limit: Optional[int] = None,
room_ids: Optional[List[str]] = None,
- include_offline: bool = True,
+ is_guest: bool = False,
explicit_room_id: Optional[str] = None,
- **kwargs,
+ include_offline: bool = True,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events is:

# 1. Get the rooms the user is in.
@@ -2074,7 +2085,7 @@ class PresenceFederationQueue:
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
- def _clear_queue(self):
+ def _clear_queue(self) -> None:
"""Clear out older entries from the queue."""
clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
@@ -2205,7 +2216,7 @@ class PresenceFederationQueue:
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
- ):
+ ) -> None:
if stream_name != PresenceFederationStream.NAME:
return
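Throughout this diff, options are read from their owning config section (`hs.config.server` here, and later `email`, `push` and `worker`) instead of from attributes flattened onto the root config. A minimal sketch of the new access pattern (the helper function is hypothetical, not part of this diff):

    from synapse.server import HomeServer

    def presence_enabled(hs: HomeServer) -> bool:
        # Previously read as hs.config.use_presence (flattened onto the
        # root config); the option now lives on the server section.
        return hs.config.server.use_presence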
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 51adf8762d..f06070bfcf 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -214,11 +214,10 @@ class ProfileHandler(BaseHandler):
target_user.localpart, displayname_to_set
)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(target_user.localpart)
- await self.user_directory_handler.handle_local_profile_change(
- target_user.to_string(), profile
- )
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
await self._update_join_states(requester, target_user)
@@ -254,7 +253,7 @@ class ProfileHandler(BaseHandler):
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
- ):
+ ) -> None:
"""Set a new avatar URL for a user.
Args:
@@ -300,11 +299,10 @@ class ProfileHandler(BaseHandler):
target_user.localpart, avatar_url_to_set
)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(target_user.localpart)
- await self.user_directory_handler.handle_local_profile_change(
- target_user.to_string(), profile
- )
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ target_user.to_string(), profile
+ )
await self._update_join_states(requester, target_user)
@@ -425,7 +423,7 @@ class ProfileHandler(BaseHandler):
raise
@wrap_as_background_process("Update remote profile")
- async def _update_remote_profile_cache(self):
+ async def _update_remote_profile_cache(self) -> None:
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a49b8ee4b1..5881f09ebd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
+from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
if TYPE_CHECKING:
@@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource:
+class ReceiptEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.config
@@ -216,7 +217,13 @@ class ReceiptEventSource:
return visible_events
async def get_new_events(
- self, from_key: int, room_ids: List[str], user: UserID, **kwargs
+ self,
+ user: UserID,
+ from_key: int,
+ limit: Optional[int],
+ room_ids: Iterable[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
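`ReceiptEventSource` (and the presence, room and typing sources elsewhere in this diff) now inherit from a generic `EventSource`, which pins down a single `get_new_events` signature for all stream sources. A sketch of the presumed shape of `synapse.streams.EventSource`, reconstructed from the call sites in this diff rather than copied from the module:

    from abc import ABCMeta, abstractmethod
    from typing import Generic, Iterable, List, Optional, Tuple, TypeVar

    from synapse.types import UserID

    K = TypeVar("K")  # stream key type, e.g. int or RoomStreamToken
    R = TypeVar("R")  # item type the source yields, e.g. JsonDict or EventBase

    class EventSource(Generic[K, R], metaclass=ABCMeta):
        @abstractmethod
        async def get_new_events(
            self,
            user: UserID,
            from_key: K,
            limit: Optional[int],
            room_ids: Iterable[str],
            is_guest: bool,
            explicit_room_id: Optional[str] = None,
        ) -> Tuple[List[R], K]:
            """Return events since `from_key`, plus the new stream key."""
            ...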
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 38c4993da0..1c195c65db 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -125,7 +125,7 @@ class RegistrationHandler(BaseHandler):
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
- ):
+ ) -> None:
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@@ -295,11 +295,10 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
- if self.hs.config.user_directory_search_all_users:
- profile = await self.store.get_profileinfo(localpart)
- await self.user_directory_handler.handle_local_profile_change(
- user_id, profile
- )
+ profile = await self.store.get_profileinfo(localpart)
+ await self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
else:
# autogen a sequential user ID
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9345ae02e0..287ea2fd06 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,6 +1,4 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +20,16 @@ import math
import random
import string
from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+)
from synapse.api.constants import (
EventContentFields,
@@ -49,6 +56,7 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
+from synapse.streams import EventSource
from synapse.types import (
JsonDict,
MutableStateMap,
@@ -186,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
- ):
+ ) -> str:
"""
Args:
requester: the user requesting the upgrade
@@ -512,7 +520,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
- ):
+ ) -> None:
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
@@ -902,7 +910,7 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
- def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
+ def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -910,7 +918,7 @@ class RoomCreationHandler(BaseHandler):
return e
- async def send(etype: str, content: JsonDict, **kwargs) -> int:
+ async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
@@ -1033,7 +1041,7 @@ class RoomCreationHandler(BaseHandler):
creator_id: str,
is_public: bool,
room_version: RoomVersion,
- ):
+ ) -> str:
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
@@ -1097,7 +1105,7 @@ class RoomContextHandler:
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
- async def filter_evts(events):
+ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if use_admin_priviledge:
return events
return await filter_events_for_client(
@@ -1175,7 +1183,7 @@ class RoomContextHandler:
return results
-class RoomEventSource:
+class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -1183,8 +1191,8 @@ class RoomEventSource:
self,
user: UserID,
from_key: RoomStreamToken,
- limit: int,
- room_ids: List[str],
+ limit: Optional[int],
+ room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 81680b8dfa..c83ff585e3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -14,7 +14,7 @@
import logging
from collections import namedtuple
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -33,7 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.types import JsonDict, ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@@ -169,7 +169,7 @@ class RoomListHandler(BaseHandler):
ignore_non_federatable=from_federation,
)
- def build_room_entry(room):
+ def build_room_entry(room: JsonDict) -> JsonDict:
entry = {
"room_id": room["room_id"],
"name": room["name"],
@@ -249,10 +249,10 @@ class RoomListHandler(BaseHandler):
self,
room_id: str,
num_joined_users: int,
- cache_context,
+ cache_context: _CacheContext,
with_alias: bool = True,
allow_private: bool = False,
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""Returns the entry for a room
Args:
@@ -507,7 +507,7 @@ class RoomListNextBatch(
)
)
- def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
+ def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
return self._replace(**kwds)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 826267f47a..4969ee395b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -226,7 +226,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Optional[str],
n_invites: int,
update: bool = True,
- ):
+ ) -> None:
"""Ratelimit more than one invite sent by the given requester in the given room.
Args:
@@ -250,7 +250,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Optional[Requester],
room_id: Optional[str],
invitee_user_id: str,
- ):
+ ) -> None:
"""Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user.
@@ -387,7 +387,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def copy_room_tags_and_direct_to_room(
- self, old_room_id, new_room_id, user_id
+ self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copies the tags and direct room state from one room to another.
@@ -688,7 +688,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
" (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE,
)
- if old_membership == "ban" and action != "unban":
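+ # A banned user may be re-banned ("ban"), unbanned ("unban") or removed
+ # ("leave"); any other transition is rejected below.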
+ if old_membership == "ban" and action not in ["ban", "unban", "leave"]:
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
@@ -1050,7 +1050,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
- ):
+ ) -> None:
"""
Change the membership status of a user in a room.
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 781da9e811..4e28fb9685 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -541,7 +541,7 @@ class RoomSummaryHandler:
origin: str,
requested_room_id: str,
suggested_only: bool,
- ):
+ ) -> JsonDict:
"""
Implementation of the room hierarchy Federation API.
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 0066d570c5..185befbe9f 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -40,15 +40,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
- creation_time = attr.ib()
+ creation_time: int
# The user interactive authentication session ID associated with this SAML
# session (or None if this SAML session is for an initial login).
- ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ ui_auth_session_id: Optional[str] = None
class SamlHandler(BaseHandler):
@@ -359,7 +359,7 @@ class SamlHandler(BaseHandler):
return remote_user_id
- def expire_sessions(self):
+ def expire_sessions(self) -> None:
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
@@ -391,10 +391,10 @@ MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
}
-@attr.s
+@attr.s(auto_attribs=True)
class SamlConfig:
- mxid_source_attribute = attr.ib()
- mxid_mapper = attr.ib()
+ mxid_source_attribute: str
+ mxid_mapper: Callable[[str], str]
class DefaultSamlMappingProvider:
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a31fe3e3c7..25e6b012b7 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -17,7 +17,7 @@ import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from io import BytesIO
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
from pkg_resources import parse_version
@@ -79,7 +79,7 @@ async def _sendmail(
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
- def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
+ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
return ESMTPSenderFactory(
username,
password,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 05aa76d6a6..e044251a13 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -205,7 +205,7 @@ class SsoHandler:
self._consent_at_registration = hs.config.consent.user_consent_at_registration
- def register_identity_provider(self, p: SsoIdentityProvider):
+ def register_identity_provider(self, p: SsoIdentityProvider) -> None:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
@@ -856,7 +856,7 @@ class SsoHandler:
async def handle_terms_accepted(
self, request: Request, session_id: str, terms_version: str
- ):
+ ) -> None:
"""Handle a request to the new-user 'consent' endpoint
Will serve an HTTP response to the request.
@@ -959,7 +959,7 @@ class SsoHandler:
new_user=True,
)
- def _expire_old_sessions(self):
+ def _expire_old_sessions(self) -> None:
to_expire = []
now = int(self._clock.time_msec())
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index b64ce8cab8..9fc53333fc 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -68,7 +68,7 @@ class StatsHandler:
self._is_processing = True
- async def process():
+ async def process() -> None:
try:
await self._unsafe_process()
finally:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index edfdb99cbd..2c7c6d63a9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -364,7 +364,9 @@ class SyncHandler:
)
else:
- async def current_sync_callback(before_token, after_token) -> SyncResult:
+ async def current_sync_callback(
+ before_token: StreamToken, after_token: StreamToken
+ ) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
@@ -441,7 +443,7 @@ class SyncHandler:
room_ids = sync_result_builder.joined_room_ids
- typing_source = self.event_sources.sources["typing"]
+ typing_source = self.event_sources.sources.typing
typing, typing_key = await typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
@@ -463,7 +465,7 @@ class SyncHandler:
receipt_key = since_token.receipt_key if since_token else 0
- receipt_source = self.event_sources.sources["receipt"]
+ receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
@@ -1090,7 +1092,7 @@ class SyncHandler:
block_all_presence_data = (
since_token is None and sync_config.filter_collection.blocks_all_presence()
)
- if self.hs_config.use_presence and not block_all_presence_data:
+ if self.hs_config.server.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
sync_result_builder,
@@ -1413,7 +1415,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
user = sync_result_builder.sync_config.user
- presence_source = self.event_sources.sources["presence"]
+ presence_source = self.event_sources.sources.presence
since_token = sync_result_builder.since_token
presence_key = None
@@ -1532,9 +1534,9 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
- async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
+ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
- res = await self._generate_room_entry(
+ await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
@@ -1544,7 +1546,6 @@ class SyncHandler:
always_include=sync_result_builder.full_state,
)
logger.debug("Generated room entry for %s", room_entry.room_id)
- return res
await concurrently_execute(handle_room_entries, room_entries, 10)
@@ -1925,7 +1926,7 @@ class SyncHandler:
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
- ):
+ ) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 9cea011e62..9326330c90 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.replication.tcp.streams import TypingStream
+from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
@@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
raise Exception("Typing writer instance got typing info over replication")
-class TypingNotificationEventSource:
+class TypingNotificationEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
@@ -485,7 +486,13 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
async def get_new_events(
- self, from_key: int, room_ids: Iterable[str], **kwargs
+ self,
+ user: UserID,
+ from_key: int,
+ limit: Optional[int],
+ room_ids: Iterable[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index d3828dec6b..ea9325e96a 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -70,7 +70,7 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
class TermsAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.TERMS
- def is_enabled(self):
+ def is_enabled(self) -> bool:
return True
async def check_auth(self, authdict: dict, clientip: str) -> Any:
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 6faa1d84be..8dc46d7674 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -114,7 +114,7 @@ class UserDirectoryHandler(StateDeltasHandler):
if self._is_processing:
return
- async def process():
+ async def process() -> None:
try:
await self._unsafe_process()
finally:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index c2ea51ee16..5204c3d08c 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -321,8 +321,11 @@ class SimpleHttpClient:
self.user_agent = hs.version_string
self.clock = hs.get_clock()
- if hs.config.user_agent_suffix:
- self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+ if hs.config.server.user_agent_suffix:
+ self.user_agent = "%s %s" % (
+ self.user_agent,
+ hs.config.server.user_agent_suffix,
+ )
# We use this for our body producers to ensure that they use the correct
# reactor.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2e9898997c..ef10ec0937 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -66,7 +66,7 @@ from synapse.http.client import (
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging import opentracing
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
@@ -553,20 +553,29 @@ class MatrixFederationHttpClient:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
- request_deferred = self.agent.request(
+
+ # To preserve the logging context, the timeout is treated
+ # in a similar way to `defer.gatherResults`:
+ # * Each logging context-preserving fork is wrapped in
+ # `run_in_background`. In this case there is only one,
+ # since the timeout fork is not logging-context aware.
+ # * The `Deferred` that joins the forks back together is
+ # wrapped in `make_deferred_yieldable` to restore the
+ # logging context regardless of the path taken.
+ request_deferred = run_in_background(
+ self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
-
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
- response = await request_deferred
+ response = await make_deferred_yieldable(request_deferred)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
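The comment block added above describes the timeout pattern in the abstract; a minimal sketch of the same shape using Synapse's helpers (the `fetch` coroutine and this wrapper are hypothetical, not part of this diff):

    from synapse.logging.context import make_deferred_yieldable, run_in_background
    from synapse.util.async_helpers import timeout_deferred

    async def fetch_with_timeout(reactor, fetch, timeout_sec: float):
        # Fork the request in the background so our logging context is
        # neither leaked into nor lost from the request's callbacks.
        d = run_in_background(fetch)

        # timeout_deferred is not logging-context aware, so it wraps the
        # already-backgrounded deferred.
        d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)

        # Rejoin the fork: restore our logging context whichever of the
        # request or the timeout fires first.
        return await make_deferred_yieldable(d)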
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c665a9d5db..dd4c749e16 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -21,7 +21,7 @@ from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
-from twisted.web.resource import IResource
+from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@@ -61,7 +61,7 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
- def __init__(self, channel, *args, max_request_body_size=1024, **kw):
+ def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site
@@ -83,13 +83,13 @@ class SynapseRequest(Request):
self._is_processing = False
# the time when the asynchronous request handler completed its processing
- self._processing_finished_time = None
+ self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection
# dropped)
- self.finish_time = None
+ self.finish_time: Optional[float] = None
- def __repr__(self):
+ def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
@@ -100,7 +100,7 @@ class SynapseRequest(Request):
self.site.site_tag,
)
- def handleContentChunk(self, data):
+ def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +139,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
- def get_request_id(self):
+ def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
@@ -205,7 +205,7 @@ class SynapseRequest(Request):
return None, None
- def render(self, resrc):
+ def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@@ -282,7 +282,7 @@ class SynapseRequest(Request):
if self.finish_time is not None:
self._finished_processing()
- def finish(self):
+ def finish(self) -> None:
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +295,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +327,7 @@ class SynapseRequest(Request):
if not self._is_processing:
self._finished_processing()
- def _started_processing(self, servlet_name):
+ def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
@@ -354,9 +354,11 @@ class SynapseRequest(Request):
self.get_redacted_uri(),
)
- def _finished_processing(self):
+ def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
+ assert self.finish_time is not None
+
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
@@ -437,7 +439,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False
- def requestReceived(self, command, path, version):
+ def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
@@ -445,7 +447,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
- def _process_forwarded_headers(self):
+ def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
@@ -470,7 +472,7 @@ class XForwardedForRequest(SynapseRequest):
)
self._forwarded_https = True
- def isSecure(self):
+ def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
@@ -545,14 +547,16 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
- def request_factory(channel, queued) -> Request:
+ def request_factory(channel, queued: bool) -> Request:
return request_class(
- channel, max_request_body_size=max_request_body_size, queued=queued
+ channel,
+ max_request_body_size=max_request_body_size,
+ queued=queued,
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
- def log(self, request):
+ def log(self, request: SynapseRequest) -> None:
pass
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2d403532fa..3196c2bec6 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -91,7 +91,7 @@ class ModuleApi:
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
- self._presence_stream = hs.get_event_sources().sources["presence"]
+ self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index bbe337949a..1a9f84ba45 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -584,7 +584,7 @@ class Notifier:
events: List[EventBase] = []
end_token = from_token
- for name, source in self.event_sources.sources.items():
+ for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
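These call sites (`sources.typing`, `sources.receipt` and `sources.presence` in sync.py and module_api, plus `get_sources()` here) imply that `EventSources.sources` changed from a plain dict to an attrs container with one named field per source. A sketch of the presumed shape, reconstructed from usage rather than copied from the module:

    import attr

    @attr.s(frozen=True, slots=True, auto_attribs=True)
    class _EventSourcesInner:
        room: "RoomEventSource"
        presence: "PresenceEventSource"
        typing: "TypingNotificationEventSource"
        receipt: "ReceiptEventSource"

        def get_sources(self):
            # Yield (name, source) pairs, mirroring the old dict's .items().
            return attr.asdict(self, recurse=False).items()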
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 250a4861b0..33430b167c 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -370,7 +370,7 @@ class HttpPusher(Pusher):
if event.type == "m.room.member" and event.is_state():
d["notification"]["membership"] = event.content["membership"]
d["notification"]["user_is_target"] = event.state_key == self.user_id
- if self.hs.config.push_include_content and event.content:
+ if self.hs.config.push.push_include_content and event.content:
d["notification"]["content"] = event.content
# We no longer send aliases separately; instead, we send the human
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b89c6e6f2b..e38e3c5d44 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name
- self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
+ self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
logger.info("Created Mailer for app_name %s" % app_name)
@@ -796,8 +796,8 @@ class Mailer:
Returns:
A link to open a room in the web client.
"""
- if self.hs.config.email_riot_base_url:
- base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
+ if self.hs.config.email.email_riot_base_url:
+ base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
@@ -815,9 +815,9 @@ class Mailer:
Returns:
A link to open the notification in the web client.
"""
- if self.hs.config.email_riot_base_url:
+ if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % (
- self.hs.config.email_riot_base_url,
+ self.hs.config.email.email_riot_base_url,
notif["room_id"],
notif["event_id"],
)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 021275437c..29ed346d37 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -35,12 +35,12 @@ class PusherFactory:
"http": HttpPusher
}
- logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
- if hs.config.email_enable_notifs:
+ logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs)
+ if hs.config.email.email_enable_notifs:
self.mailers: Dict[str, Mailer] = {}
- self._notif_template_html = hs.config.email_notif_template_html
- self._notif_template_text = hs.config.email_notif_template_text
+ self._notif_template_html = hs.config.email.email_notif_template_html
+ self._notif_template_text = hs.config.email.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index a1436f3930..26735447a6 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool:
self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
- self._pusher_shard_config = hs.config.push.pusher_shard_config
+ self._pusher_shard_config = hs.config.worker.pusher_shard_config
self._instance_name = hs.get_instance_name()
self._should_start_pushers = (
self._instance_name in self._pusher_shard_config.instances
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3adc576124..e04af705eb 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import JsonResource
+from typing import TYPE_CHECKING
+
+from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
@@ -57,6 +59,9 @@ from synapse.rest.client import (
voip,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
- def register_servlets(client_resource, hs):
+ def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)
# Deprecated in r0
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 5715190a78..a6fa03c90f 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore()
async def on_GET(
- self, request: SynapseRequest, user_id, device_id: str
+ self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ad83d4b54c..8f781f745f 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -125,7 +125,7 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- search_term = parse_string(request, "search_term")
+ search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
raise SynapseError(
400,
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index f5a38c2670..19f84f33f2 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, json_resource: HttpServer):
+ def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index c1a1ba645e..681e491826 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {}
self.hs = hs
- def _clear_old_nonces(self):
+ def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index ed96978448..bf14ec384e 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,6 +14,7 @@
import logging
import re
+from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request
@@ -42,25 +43,25 @@ logger = logging.getLogger(__name__)
class RoomBatchSendEventRestServlet(RestServlet):
"""
- API endpoint which can insert a chunk of events historically back in time
+ API endpoint which can insert a batch of events historically back in time
next to the given `prev_event`.
- `chunk_id` comes from `next_chunk_id `in the response of the batch send
- endpoint and is derived from the "insertion" events added to each chunk.
+ `batch_id` comes from `next_batch_id` in the response of the batch send
+ endpoint and is derived from the "insertion" events added to each batch.
It's not required for the first batch send.
`state_events_at_start` is used to define the historical state events
needed to auth the events, like join events. These events will float
outside of the normal DAG as outliers and won't be visible in the chat
- history which also allows us to insert multiple chunks without having a bunch
- of `@mxid joined the room` noise between each chunk.
+ history which also allows us to insert multiple batches without having a bunch
+ of `@mxid joined the room` noise between each batch.
- `events` is chronological chunk/list of events you want to insert.
- There is a reverse-chronological constraint on chunks so once you insert
+ `events` is a chronological list of events you want to insert.
+ There is a reverse-chronological constraint on batches so once you insert
some messages, you can only insert older ones after that.
- tldr; Insert chunks from your most recent history -> oldest history.
+ tldr; Insert batches from your most recent history -> oldest history.
- POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
+ POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event_id=<eventID>&batch_id=<batchID>
{
"events": [ ... ],
"state_events_at_start": [ ... ]
@@ -128,7 +129,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self, sender: str, room_id: str, origin_server_ts: int
) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
- and a random chunk ID.
+ and a random batch ID.
Args:
sender: The event author MXID
@@ -139,13 +140,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
The new event dictionary to insert.
"""
- next_chunk_id = random_string(8)
+ next_batch_id = random_string(8)
insertion_event = {
"type": EventTypes.MSC2716_INSERTION,
"sender": sender,
"room_id": room_id,
"content": {
- EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+ EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
EventContentFields.MSC2716_HISTORICAL: True,
},
"origin_server_ts": origin_server_ts,
@@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if not requester.app_service:
raise AuthError(
- 403,
+ HTTPStatus.FORBIDDEN,
"Only application services can use the /batchsend endpoint",
)
@@ -187,24 +188,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"])
assert request.args is not None
- prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
- chunk_id_from_query = parse_string(request, "chunk_id")
+ prev_event_ids_from_query = parse_strings_from_args(
+ request.args, "prev_event_id"
+ )
+ batch_id_from_query = parse_string(request, "batch_id")
- if prev_events_from_query is None:
+ if prev_event_ids_from_query is None:
raise SynapseError(
- 400,
+ HTTPStatus.BAD_REQUEST,
"prev_event query parameter is required when inserting historical messages back in time",
errcode=Codes.MISSING_PARAM,
)
- # For the event we are inserting next to (`prev_events_from_query`),
+ # For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
(
most_recent_prev_event_id,
_,
- ) = await self.store.get_max_depth_of(prev_events_from_query)
+ ) = await self.store.get_max_depth_of(prev_event_ids_from_query)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id
@@ -213,7 +216,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
- state_events_at_start = []
+ state_event_ids_at_start = []
for state_event in body["state_events_at_start"]:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@@ -279,27 +282,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
event_id = event.event_id
- state_events_at_start.append(event_id)
+ state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
events_to_create = body["events"]
inherited_depth = await self._inherit_depth_from_prev_ids(
- prev_events_from_query
+ prev_event_ids_from_query
)
- # Figure out which chunk to connect to. If they passed in
- # chunk_id_from_query let's use it. The chunk ID passed in comes
- # from the chunk_id in the "insertion" event from the previous chunk.
- last_event_in_chunk = events_to_create[-1]
- chunk_id_to_connect_to = chunk_id_from_query
+ # Figure out which batch to connect to. If they passed in
+ # batch_id_from_query let's use it. The batch ID passed in comes
+ # from the batch_id in the "insertion" event from the previous batch.
+ last_event_in_batch = events_to_create[-1]
+ batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None
- if chunk_id_from_query:
+ if batch_id_from_query:
# All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of
- # the chunk later.
+ # the batch later.
prev_event_ids = [fake_prev_event_id]
- # TODO: Verify the chunk_id_from_query corresponds to an insertion event
+
+ # Verify the batch_id_from_query corresponds to an actual insertion event
+ # so the new batch can be connected to it.
+ corresponding_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
+ )
+ if corresponding_insertion_event_id is None:
+ raise SynapseError(
+ 400,
+ "No insertion event corresponds to the given ?batch_id",
+ errcode=Codes.INVALID_PARAM,
+ )
- pass
# Otherwise, create an insertion event to act as a starting point.
#
@@ -309,12 +323,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
# an insertion event), in which case we just create a new insertion event
# that can then get pointed to by a "marker" event later.
else:
- prev_event_ids = prev_events_from_query
+ prev_event_ids = prev_event_ids_from_query
base_insertion_event_dict = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
- origin_server_ts=last_event_in_chunk["origin_server_ts"],
+ origin_server_ts=last_event_in_batch["origin_server_ts"],
)
base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@@ -333,38 +347,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
- chunk_id_to_connect_to = base_insertion_event["content"][
- EventContentFields.MSC2716_NEXT_CHUNK_ID
+ batch_id_to_connect_to = base_insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
]
- # Connect this current chunk to the insertion event from the previous chunk
- chunk_event = {
- "type": EventTypes.MSC2716_CHUNK,
+ # Connect this current batch to the insertion event from the previous batch
+ batch_event = {
+ "type": EventTypes.MSC2716_BATCH,
"sender": requester.user.to_string(),
"room_id": room_id,
"content": {
- EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to,
+ EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
EventContentFields.MSC2716_HISTORICAL: True,
},
- # Since the chunk event is put at the end of the chunk,
+ # Since the batch event is put at the end of the batch,
# where the newest-in-time event is, copy the origin_server_ts from
# the last event we're inserting
- "origin_server_ts": last_event_in_chunk["origin_server_ts"],
+ "origin_server_ts": last_event_in_batch["origin_server_ts"],
}
- # Add the chunk event to the end of the chunk (newest-in-time)
- events_to_create.append(chunk_event)
+ # Add the batch event to the end of the batch (newest-in-time)
+ events_to_create.append(batch_event)
- # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
- # event in the chunk) so the next chunk can be connected to this one.
+ # Add an "insertion" event to the start of each batch (next to the oldest-in-time
+ # event in the batch) so the next batch can be connected to this one.
insertion_event = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
- # Since the insertion event is put at the start of the chunk,
+ # Since the insertion event is put at the start of the batch,
# where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"],
)
- # Prepend the insertion event to the start of the chunk (oldest-in-time)
+ # Prepend the insertion event to the start of the batch (oldest-in-time)
events_to_create = [insertion_event] + events_to_create
event_ids = []
@@ -424,20 +438,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
context=context,
)
- # Add the base_insertion_event to the bottom of the list we return
- if base_insertion_event is not None:
- event_ids.append(base_insertion_event.event_id)
+ insertion_event_id = event_ids[0]
+ batch_event_id = event_ids[-1]
+ historical_event_ids = event_ids[1:-1]
- return 200, {
- "state_events": state_events_at_start,
- "events": event_ids,
- "next_chunk_id": insertion_event["content"][
- EventContentFields.MSC2716_NEXT_CHUNK_ID
+ response_dict = {
+ "state_event_ids": state_event_ids_at_start,
+ "event_ids": historical_event_ids,
+ "next_batch_id": insertion_event["content"][
+ EventContentFields.MSC2716_NEXT_BATCH_ID
],
+ "insertion_event_id": insertion_event_id,
+ "batch_event_id": batch_event_id,
}
+ if base_insertion_event is not None:
+ response_dict["base_insertion_event_id"] = base_insertion_event.event_id
+
+ return HTTPStatus.OK, response_dict
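# For reference, a successful response from this servlet now has roughly the
# following shape (the event IDs below are hypothetical illustrations, not
# values produced by the code above):
#
#   {
#       "state_event_ids": ["$stateEvent1", "$stateEvent2"],
#       "event_ids": ["$historicalEvent1", "$historicalEvent2"],
#       "next_batch_id": "randomBatchId123",
#       "insertion_event_id": "$insertionEvent",
#       "batch_event_id": "$batchEvent",
#       # only present when no ?batch_id was given and a base insertion
#       # event was created:
#       "base_insertion_event_id": "$baseInsertionEvent"
#   }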
def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
- return 501, "Not implemented"
+ return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
def on_PUT(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 11f7320832..06e0fbde22 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -17,17 +17,22 @@ import logging
from hashlib import sha256
from http import HTTPStatus
from os import path
-from typing import Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List
import jinja2
from jinja2 import TemplateNotFound
+from twisted.web.server import Request
+
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"
@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): homeserver
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- async def _async_render_GET(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- async def _async_render_POST(self, request):
- """
- Args:
- request (twisted.web.http.Request):
- """
+ async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- args: Dict[bytes, List[bytes]] = request.args
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
@@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("success.html not found")
- def _render_template(self, request, template_name, **template_args):
+ def _render_template(
+ self, request: Request, template_name: str, **template_args: Any
+ ) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
@@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)
- def _check_hash(self, userid, userhmac):
+ def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
- userid (unicode):
- userhmac (bytes):
+            userid: The user ID to check the hash for.
+            userhmac: The HMAC supplied in the request.
Raises:
SynapseError if the hash doesn't match
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
index 4487b54abf..78df7af2cf 100644
--- a/synapse/rest/health.py
+++ b/synapse/rest/health.py
@@ -13,6 +13,7 @@
# limitations under the License.
from twisted.web.resource import Resource
+from twisted.web.server import Request
class HealthResource(Resource):
@@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index c6c63073ea..7f8c1de1ff 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
+
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
class KeyApiV2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 25f6eb842f..ebe243bcfd 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
+from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -58,18 +63,18 @@ class LocalKey(Resource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
- def update_response_body(self, time_now_msec):
+ def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
- def response_json_object(self):
+ def response_json_object(self) -> JsonDict:
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
@@ -94,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
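# A worked example of the refresh rule above (interval value assumed): with
# key_refresh_interval = 86400000 (one day, in milliseconds) and a response
# generated at time T, valid_until_ts is T + 86400000; any request arriving
# after T + 43200000 (half the interval) satisfies the check and regenerates
# the response body before serving it.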
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 744360e5fd..d8fd7938a4 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,17 +13,23 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
+from twisted.web.server import Request
+
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
@@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
+ assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
@@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- async def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(
+ self,
+ request: Request,
+ query: JsonDict,
+ query_remote_on_cache_miss: bool = False,
+ ) -> None:
logger.info("Handling query for keys %r", query)
store_queries = []
@@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), results in cached.items():
- results = [(result["ts_added_ms"], result) for result in results]
+ for (server_name, key_id, _), key_results in cached.items():
+ results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
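# The `query` argument follows the key/v2 query format; a minimal example
# (the server name, key ID and timestamp are made up):
#
#   {
#       "matrix.example.com": {
#           "ed25519:abc123": {"minimum_valid_until_ts": 1234567890123}
#       }
#   }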
@@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
- results = {"server_keys": signed_keys}
+ response = {"server_keys": signed_keys}
- respond_with_json(request, 200, results, canonical_json=True)
+ respond_with_json(request, 200, response, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 90364ebcf7..7c881f2bdb 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,7 +16,10 @@
import logging
import os
import urllib
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
@@ -120,7 +123,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any.
"""
- def _quote(x):
+ def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types.
@@ -280,51 +283,74 @@ class Responder:
"""
pass
- def __enter__(self):
+ def __enter__(self) -> None:
pass
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
pass
-class FileInfo:
- """Details about a requested/uploaded file.
-
- Attributes:
- server_name (str): The server name where the media originated from,
- or None if local.
- file_id (str): The local ID of the file. For local files this is the
- same as the media_id
- url_cache (bool): If the file is for the url preview cache
- thumbnail (bool): Whether the file is a thumbnail or not.
- thumbnail_width (int)
- thumbnail_height (int)
- thumbnail_method (str)
- thumbnail_type (str): Content type of thumbnail, e.g. image/png
- thumbnail_length (int): The size of the media file, in bytes.
- """
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+ """Details about a generated thumbnail."""
- def __init__(
- self,
- server_name,
- file_id,
- url_cache=False,
- thumbnail=False,
- thumbnail_width=None,
- thumbnail_height=None,
- thumbnail_method=None,
- thumbnail_type=None,
- thumbnail_length=None,
- ):
- self.server_name = server_name
- self.file_id = file_id
- self.url_cache = url_cache
- self.thumbnail = thumbnail
- self.thumbnail_width = thumbnail_width
- self.thumbnail_height = thumbnail_height
- self.thumbnail_method = thumbnail_method
- self.thumbnail_type = thumbnail_type
- self.thumbnail_length = thumbnail_length
+ width: int
+ height: int
+ method: str
+ # Content type of thumbnail, e.g. image/png
+ type: str
+ # The size of the media file, in bytes.
+ length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+ """Details about a requested/uploaded file."""
+
+ # The server name where the media originated from, or None if local.
+ server_name: Optional[str]
+ # The local ID of the file. For local files this is the same as the media_id
+ file_id: str
+ # If the file is for the url preview cache
+ url_cache: bool = False
+    # Details of the thumbnail, if this file is a thumbnail; otherwise None.
+ thumbnail: Optional[ThumbnailInfo] = None
+
+ # The below properties exist to maintain compatibility with third-party modules.
+ @property
+ def thumbnail_width(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.width
+
+ @property
+ def thumbnail_height(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.height
+
+ @property
+ def thumbnail_method(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.method
+
+ @property
+ def thumbnail_type(self) -> Optional[str]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.type
+
+ @property
+ def thumbnail_length(self) -> Optional[int]:
+ if not self.thumbnail:
+ return None
+ return self.thumbnail.length
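# A minimal sketch of the compatibility shim above: code written against the
# old flat attributes keeps working with the nested ThumbnailInfo (the values
# here are illustrative):
#
#   info = FileInfo(
#       server_name=None,
#       file_id="exampleMediaId",
#       thumbnail=ThumbnailInfo(width=32, height=32, method="crop", type="image/png"),
#   )
#   info.thumbnail_width   # -> 32, proxied through info.thumbnail.width
#   FileInfo(server_name=None, file_id="other").thumbnail_length  # -> None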
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 09531ebf54..39bbe4e874 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -16,7 +16,7 @@
import functools
import os
import re
-from typing import Callable, List
+from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
"""
@functools.wraps(func)
- def _wrapped(self, *args, **kwargs):
+ def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
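# For context, the wrapper above is applied to the relative-path helpers in
# this class, roughly as follows (the helper body is a simplified assumption,
# not the real directory layout logic):
#
#   def local_media_filepath_rel(self, media_id: str) -> str:
#       return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
#
#   local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
#
# so callers always receive absolute paths rooted at self.base_path.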
@@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
- ):
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0f5ce41ff8..50e4c9e29f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -32,6 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
+from synapse.config.repository import ThumbnailRequirement
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@@ -42,6 +44,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
+ ThumbnailInfo,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -113,7 +116,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
- def _start_update_recently_accessed(self):
+ def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
)
@@ -210,7 +213,7 @@ class MediaRepository:
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
- file_info = FileInfo(None, media_id, url_cache=url_cache)
+ file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
@@ -468,7 +471,9 @@ class MediaRepository:
return media_info
- def _get_thumbnail_requirements(self, media_type):
+ def _get_thumbnail_requirements(
+ self, media_type: str
+ ) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";")
if scpos > 0:
media_type = media_type[:scpos]
@@ -514,7 +519,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
- url_cache: Optional[str],
+ url_cache: bool,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
@@ -548,11 +553,12 @@ class MediaRepository:
server_name=None,
file_id=media_id,
url_cache=url_cache,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -585,7 +591,7 @@ class MediaRepository:
t_type: str,
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
- FileInfo(server_name, file_id, url_cache=False)
+ FileInfo(server_name, file_id)
)
try:
@@ -616,11 +622,12 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
output_path = await self.media_storage.store_file(
@@ -742,12 +749,13 @@ class MediaRepository:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
url_cache=url_cache,
+ thumbnail=ThumbnailInfo(
+ width=t_width,
+ height=t_height,
+ method=t_method,
+ type=t_type,
+ ),
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 56cdc1b4ed..01fada8fb5 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -15,7 +15,20 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+from types import TracebackType
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Generator,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+)
import attr
@@ -83,12 +96,14 @@ class MediaStorage:
return fname
- async def write_to_file(self, source: IO, output: IO):
+ async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
- def store_into_file(self, file_info: FileInfo):
+ def store_into_file(
+ self, file_info: FileInfo
+ ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -125,7 +140,7 @@ class MediaStorage:
try:
with open(fname, "wb") as f:
- async def finish():
+ async def finish() -> None:
# Ensure that all writes have been flushed and close the
# file.
f.flush()
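# A usage sketch for store_into_file (the surrounding coroutine and
# `downloaded_bytes` are assumed): callers write into the yielded file object
# and must await `finish` before leaving the block, so that the data is
# flushed and handed off to any configured storage providers.
#
#   with media_storage.store_into_file(file_info) as (f, fname, finish):
#       f.write(downloaded_bytes)
#       await finish()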
@@ -176,9 +191,9 @@ class MediaStorage:
self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
)
@@ -220,9 +235,9 @@ class MediaStorage:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
@@ -255,10 +270,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
@@ -267,10 +282,10 @@ class MediaStorage:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id
@@ -279,10 +294,10 @@ class MediaStorage:
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
- width=file_info.thumbnail_width,
- height=file_info.thumbnail_height,
- content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method,
+ width=file_info.thumbnail.width,
+ height=file_info.thumbnail.height,
+ content_type=file_info.thumbnail.type,
+ method=file_info.thumbnail.method,
)
return self.filepaths.local_media_filepath_rel(file_info.file_id)
@@ -315,7 +330,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer)
)
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> None:
self.open_file.close()
@@ -339,7 +359,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock)
path = attr.ib(type=str)
- async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2e6706dbfa..8b74e72655 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import urllib.parse
from typing import TYPE_CHECKING, Optional
import attr
from synapse.http.client import SimpleHttpClient
+from synapse.types import JsonDict
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -24,18 +27,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class OEmbedResult:
- # Either HTML content or URL must be provided.
- html: Optional[str]
- url: Optional[str]
- title: Optional[str]
- # Number of seconds to cache the content.
- cache_age: int
-
-
-class OEmbedError(Exception):
- """An error occurred processing the oEmbed object."""
+ # The Open Graph result (converted from the oEmbed result).
+ open_graph_result: JsonDict
+ # Number of seconds to cache the content, according to the oEmbed response.
+ #
+ # This will be None if no cache-age is provided in the oEmbed response (or
+ # if the oEmbed response cannot be turned into an Open Graph response).
+ cache_age: Optional[int]
class OEmbedProvider:
@@ -81,75 +81,106 @@ class OEmbedProvider:
"""
for url_pattern, endpoint in self._oembed_patterns.items():
if url_pattern.fullmatch(url):
- return endpoint
+ # TODO Specify max height / width.
+
+ # Note that only the JSON format is supported, some endpoints want
+ # this in the URL, others want it as an argument.
+ endpoint = endpoint.replace("{format}", "json")
+
+ args = {"url": url, "format": "json"}
+ query_str = urllib.parse.urlencode(args, True)
+ return f"{endpoint}?{query_str}"
# No match.
return None
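# An example of the URL built above (the endpoint is assumed to come from the
# oEmbed provider configuration): for
# url = "https://twitter.com/matrixdotorg/status/1" and an endpoint of
# "https://publish.twitter.com/oembed", this returns
#
#   https://publish.twitter.com/oembed?url=https%3A%2F%2Ftwitter.com%2Fmatrixdotorg%2Fstatus%2F1&format=json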
- async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"""
- Request content from an oEmbed endpoint.
+ Parse the oEmbed response into an Open Graph response.
Args:
- endpoint: The oEmbed API endpoint.
- url: The URL to pass to the API.
+ url: The URL which is being previewed (not the one which was
+ requested).
+ raw_body: The oEmbed response as JSON encoded as bytes.
Returns:
- An object representing the metadata returned.
-
- Raises:
- OEmbedError if fetching or parsing of the oEmbed information fails.
+            An OEmbedResult wrapping the parsed Open Graph data and cache age.
"""
- try:
- logger.debug("Trying to get oEmbed content for url '%s'", url)
- # Note that only the JSON format is supported, some endpoints want
- # this in the URL, others want it as an argument.
- endpoint = endpoint.replace("{format}", "json")
-
- result = await self._client.get_json(
- endpoint,
- # TODO Specify max height / width.
- args={"url": url, "format": "json"},
- )
+ try:
+ # oEmbed responses *must* be UTF-8 according to the spec.
+ oembed = json_decoder.decode(raw_body.decode("utf-8"))
# Ensure there's a version of 1.0.
- if result.get("version") != "1.0":
- raise OEmbedError("Invalid version: %s" % (result.get("version"),))
-
- oembed_type = result.get("type")
+ oembed_version = oembed["version"]
+ if oembed_version != "1.0":
+ raise RuntimeError(f"Invalid version: {oembed_version}")
# Ensure the cache age is None or an int.
- cache_age = result.get("cache_age")
+ cache_age = oembed.get("cache_age")
if cache_age:
cache_age = int(cache_age)
- oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+ # The results.
+ open_graph_response = {"og:title": oembed.get("title")}
- # HTML content.
+ # If a thumbnail exists, use it. Note that dimensions will be calculated later.
+ if "thumbnail_url" in oembed:
+ open_graph_response["og:image"] = oembed["thumbnail_url"]
+
+ # Process each type separately.
+ oembed_type = oembed["type"]
if oembed_type == "rich":
- oembed_result.html = result.get("html")
- return oembed_result
+ calc_description_and_urls(open_graph_response, oembed["html"])
- if oembed_type == "photo":
- oembed_result.url = result.get("url")
- return oembed_result
+ elif oembed_type == "photo":
+ # If this is a photo, use the full image, not the thumbnail.
+ open_graph_response["og:image"] = oembed["url"]
- # TODO Handle link and video types.
+ else:
+ raise RuntimeError(f"Unknown oEmbed type: {oembed_type}")
- if "thumbnail_url" in result:
- oembed_result.url = result.get("thumbnail_url")
- return oembed_result
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ logger.warning(f"Error parsing oEmbed metadata from {url}: {e:r}")
+ open_graph_response = {}
+ cache_age = None
- raise OEmbedError("Incompatible oEmbed information.")
+ return OEmbedResult(open_graph_response, cache_age)
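# A sketch of the conversion performed above, for a hypothetical photo-type
# oEmbed document (field values are made up):
#
#   raw_body = (
#       b'{"version": "1.0", "type": "photo", "title": "A cat",'
#       b' "url": "https://example.com/cat.jpg", "cache_age": "3600"}'
#   )
#   result = provider.parse_oembed_response("https://example.com/cat", raw_body)
#   result.open_graph_result
#   # -> {"og:title": "A cat", "og:image": "https://example.com/cat.jpg"}
#   result.cache_age  # -> 3600 (seconds)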
- except OEmbedError as e:
- # Trap OEmbedErrors first so we can directly re-raise them.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- raise
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
- raise OEmbedError() from e
+def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None:
+ """
+ Calculate description for an HTML document.
+
+    This uses lxml to convert the HTML document into plaintext. If errors
+    occur during processing of the document, the Open Graph response is
+    left unchanged.
+
+ Args:
+ open_graph_response: The current Open Graph summary. This is updated with additional fields.
+        html_body: The HTML document, as a string.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not html_body:
+ return
+
+ from lxml import etree
+
+    # Create an HTML parser.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+
+    # Attempt to parse the body.
+ tree = etree.fromstring(html_body, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return
+
+ from synapse.rest.media.v1.preview_url_resource import _calc_description
+
+ description = _calc_description(tree)
+ if description:
+ open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f108da05db..0a0b476d2b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,6 +27,7 @@ from urllib import parse as urlparse
import attr
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
@@ -43,7 +44,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
-from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -72,6 +73,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
ONE_HOUR = 60 * 60 * 1000
+ONE_DAY = 24 * ONE_HOUR
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -254,10 +256,19 @@ class PreviewUrlResource(DirectServeJsonResource):
og = og.encode("utf8")
return og
- media_info = await self._download_url(url, user)
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._oembed.get_oembed_url(url)
+ if oembed_url:
+ url_to_download = oembed_url
+
+ media_info = await self._download_url(url_to_download, user)
logger.debug("got media_info of '%s'", media_info)
+ # The number of milliseconds that the response should be considered valid.
+ expiration_ms = media_info.expires
+
if _is_media(media_info.media_type):
file_id = media_info.filesystem_id
dims = await self.media_repo._generate_thumbnails(
@@ -287,34 +298,22 @@ class PreviewUrlResource(DirectServeJsonResource):
encoding = get_html_media_encoding(body, media_info.media_type)
og = decode_and_calc_og(body, media_info.uri, encoding)
- # pre-cache the image for posterity
- # FIXME: it might be cleaner to use the same flow as the main /preview_url
- # request itself and benefit from the same caching etc. But for now we
- # just rely on the caching on the master request to speed things up.
- if "og:image" in og and og["og:image"]:
- image_info = await self._download_url(
- _rebase_url(og["og:image"], media_info.uri), user
- )
+ await self._precache_image_url(user, media_info, og)
+
+ elif oembed_url and _is_json(media_info.media_type):
+ # Handle an oEmbed response.
+ with open(media_info.filename, "rb") as file:
+ body = file.read()
+
+ oembed_response = self._oembed.parse_oembed_response(media_info.uri, body)
+ og = oembed_response.open_graph_result
+
+            # Use the cache age from the oEmbed result, instead of the HTTP response.
+            # The oEmbed cache age is in seconds, so convert it to milliseconds.
+            if oembed_response.cache_age is not None:
+                expiration_ms = oembed_response.cache_age * 1000
+
+ await self._precache_image_url(user, media_info, og)
- if _is_media(image_info.media_type):
- # TODO: make sure we don't choke on white-on-transparent images
- file_id = image_info.filesystem_id
- dims = await self.media_repo._generate_thumbnails(
- None, file_id, file_id, image_info.media_type, url_cache=True
- )
- if dims:
- og["og:image:width"] = dims["width"]
- og["og:image:height"] = dims["height"]
- else:
- logger.warning("Couldn't get dims for %s", og["og:image"])
-
- og[
- "og:image"
- ] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
- og["og:image:type"] = image_info.media_type
- og["matrix:image:size"] = image_info.media_length
- else:
- del og["og:image"]
else:
logger.warning("Failed to find any OG data in %s", url)
og = {}
@@ -335,12 +334,15 @@ class PreviewUrlResource(DirectServeJsonResource):
jsonog = json_encoder.encode(og)
+ # Cap the amount of time to consider a response valid.
+ expiration_ms = min(expiration_ms, ONE_DAY)
+
# store OG in history-aware DB cache
await self.store.store_url_cache(
url,
media_info.response_code,
media_info.etag,
- media_info.expires + media_info.created_ts_ms,
+ media_info.created_ts_ms + expiration_ms,
jsonog,
media_info.filesystem_id,
media_info.created_ts_ms,
@@ -357,88 +359,52 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- # If this URL can be accessed via oEmbed, use that instead.
- url_to_download: Optional[str] = url
- oembed_url = self._oembed.get_oembed_url(url)
- if oembed_url:
- # The result might be a new URL to download, or it might be HTML content.
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- oembed_result = await self._oembed.get_oembed_content(oembed_url, url)
- if oembed_result.url:
- url_to_download = oembed_result.url
- elif oembed_result.html:
- url_to_download = None
- except OEmbedError:
- # If an error occurs, try doing a normal preview.
- pass
+ logger.debug("Trying to get preview for url '%s'", url)
+ length, headers, uri, code = await self.client.get_file(
+ url,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url, e)
- if url_to_download:
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- try:
- logger.debug("Trying to get preview for url '%s'", url_to_download)
- length, headers, uri, code = await self.client.get_file(
- url_to_download,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url_to_download, e)
-
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
-
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
- download_name = get_filename_from_headers(headers)
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- expires = ONE_HOUR
- etag = (
- headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
- )
- else:
- # we can only get here if we did an oembed request and have an oembed_result.html
- assert oembed_result.html is not None
- assert oembed_url is not None
-
- html_bytes = oembed_result.html.encode("utf-8")
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- f.write(html_bytes)
- await finish()
-
- media_type = "text/html"
- download_name = oembed_result.title
- length = len(html_bytes)
- # If a specific cache age was not given, assume 1 hour.
- expires = oembed_result.cache_age or ONE_HOUR
- uri = oembed_url
- code = 200
- etag = None
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
try:
time_now_ms = self.clock.time_msec()
@@ -473,7 +439,47 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag,
)
- def _start_expire_url_cache_data(self):
+ async def _precache_image_url(
+ self, user: str, media_info: MediaInfo, og: JsonDict
+ ) -> None:
+ """
+ Pre-cache the image (if one exists) for posterity
+
+ Args:
+ user: The user requesting the preview.
+ media_info: The media being previewed.
+ og: The Open Graph dictionary. This is modified with image information.
+ """
+ # If there's no image or it is blank, there's nothing to do.
+ if "og:image" not in og or not og["og:image"]:
+ return
+
+ # FIXME: it might be cleaner to use the same flow as the main /preview_url
+ # request itself and benefit from the same caching etc. But for now we
+ # just rely on the caching on the master request to speed things up.
+ image_info = await self._download_url(
+ _rebase_url(og["og:image"], media_info.uri), user
+ )
+
+ if _is_media(image_info.media_type):
+ # TODO: make sure we don't choke on white-on-transparent images
+ file_id = image_info.filesystem_id
+ dims = await self.media_repo._generate_thumbnails(
+ None, file_id, file_id, image_info.media_type, url_cache=True
+ )
+ if dims:
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
+ else:
+ logger.warning("Couldn't get dims for %s", og["og:image"])
+
+ og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
+ og["og:image:type"] = image_info.media_type
+ og["matrix:image:size"] = image_info.media_length
+ else:
+ del og["og:image"]
+
+ def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
)
@@ -526,7 +532,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * ONE_HOUR
+ expire_before = now - 2 * ONE_DAY
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
@@ -668,7 +674,18 @@ def decode_and_calc_og(
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
- # suck our tree into lxml and define our OG response.
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to search the HTML document for Open Graph data.
+
+ Args:
+ tree: The parsed HTML document.
+        media_uri: The URI used to download the body.
+
+ Returns:
+ The Open Graph response as a dictionary.
+ """
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
@@ -742,35 +759,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
if meta_description:
og["og:description"] = meta_description[0]
else:
- # grab any text nodes which are inside the <body/> tag...
- # unless they are within an HTML5 semantic markup tag...
- # <header/>, <nav/>, <aside/>, <footer/>
- # ...or if they are within a <script/> or <style/> tag.
- # This is a very very very coarse approximation to a plain text
- # render of the page.
-
- # We don't just use XPATH here as that is slow on some machines.
-
- from lxml import etree
-
- TAGS_TO_REMOVE = (
- "header",
- "nav",
- "aside",
- "footer",
- "script",
- "noscript",
- "style",
- etree.Comment,
- )
-
- # Split all the text nodes into paragraphs (by splitting on new
- # lines)
- text_nodes = (
- re.sub(r"\s+", "\n", el).strip()
- for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
- )
- og["og:description"] = summarize_paragraphs(text_nodes)
+ og["og:description"] = _calc_description(tree)
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
@@ -781,8 +770,48 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
return og
+def _calc_description(tree: "etree.Element") -> Optional[str]:
+ """
+ Calculate a text description based on an HTML document.
+
+ Grabs any text nodes which are inside the <body/> tag, unless they are within
+ an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+ if they are within a <script/> or <style/> tag.
+
+ This is a very very very coarse approximation to a plain text render of the page.
+
+ Args:
+ tree: The parsed HTML document.
+
+ Returns:
+ The plain text description, or None if one cannot be generated.
+ """
+ # We don't just use XPATH here as that is slow on some machines.
+
+ from lxml import etree
+
+ TAGS_TO_REMOVE = (
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment,
+ )
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r"\s+", "\n", el).strip()
+ for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+ )
+ return summarize_paragraphs(text_nodes)
+
+
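# A usage sketch for the two helpers above (the input strings are made up):
# paragraphs are accumulated until at least min_size characters have been
# collected, and anything beyond max_size is cut at a word boundary with a
# trailing ellipsis.
#
#   summarize_paragraphs(["A first, fairly short paragraph.", "A second one."])
#   # -> both paragraphs joined, since together they fall under the
#   #    200-character minimum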
def _iterate_over_text(
- tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+ tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
@@ -840,11 +869,25 @@ def _is_html(content_type: str) -> bool:
)
+def _is_json(content_type: str) -> bool:
+ return content_type.lower().startswith("application/json")
+
+
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
- # Try to get a summary of between 200 and 500 words, respecting
- # first paragraph and then word boundaries.
+ """
+ Try to get a summary respecting first paragraph and then word boundaries.
+
+ Args:
+ text_nodes: The paragraphs to summarize.
+        min_size: The minimum number of characters to include.
+        max_size: The maximum number of characters to include.
+
+ Returns:
+ A summary of the text nodes, or None if that was not possible.
+ """
+
# TODO: Respect sentences?
description = ""
@@ -867,7 +910,7 @@ def summarize_paragraphs(
new_desc = ""
# This splits the paragraph into words, but keeping the
- # (preceeding) whitespace intact so we can easily concat
+ # (preceding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer(r"\s*\S+", description):
word = match.group()
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 0ff6ad3c0c..6c9969e55f 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
- async def store():
+ async def store() -> None:
try:
return await maybe_awaitable(
self.backend.store_file(path, file_info)
@@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
- def __str__(self):
+ def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None:
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 12bd745cb2..22f43d8531 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -26,6 +26,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
+ ThumbnailInfo,
parse_media_id,
respond_404,
respond_with_file,
@@ -114,7 +115,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_id,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
server_name=None,
)
@@ -149,11 +150,12 @@ class ThumbnailResource(DirectServeJsonResource):
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -173,7 +175,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height,
desired_method,
desired_type,
- url_cache=media_info["url_cache"],
+ url_cache=bool(media_info["url_cache"]),
)
if file_path:
@@ -210,11 +212,12 @@ class ThumbnailResource(DirectServeJsonResource):
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=info["thumbnail_width"],
- thumbnail_height=info["thumbnail_height"],
- thumbnail_type=info["thumbnail_type"],
- thumbnail_method=info["thumbnail_method"],
+ thumbnail=ThumbnailInfo(
+ width=info["thumbnail_width"],
+ height=info["thumbnail_height"],
+ type=info["thumbnail_type"],
+ method=info["thumbnail_method"],
+ ),
)
t_type = file_info.thumbnail_type
@@ -271,7 +274,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos,
media_id,
media_info["filesystem_id"],
- url_cache=None,
+ url_cache=False,
server_name=server_name,
)
@@ -285,7 +288,7 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos: List[Dict[str, Any]],
media_id: str,
file_id: str,
- url_cache: Optional[str] = None,
+ url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
@@ -299,7 +302,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos:
@@ -318,13 +321,16 @@ class ThumbnailResource(DirectServeJsonResource):
respond_404(request)
return
+ # The thumbnail property must exist.
+ assert file_info.thumbnail is not None
+
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
return
@@ -351,18 +357,18 @@ class ThumbnailResource(DirectServeJsonResource):
server_name,
file_id=file_id,
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
- t_width=file_info.thumbnail_width,
- t_height=file_info.thumbnail_height,
- t_method=file_info.thumbnail_method,
- t_type=file_info.thumbnail_type,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
@@ -370,8 +376,8 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_responder(
request,
responder,
- file_info.thumbnail_type,
- file_info.thumbnail_length,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
)
else:
logger.info("Failed to find any generated thumbnails")
@@ -385,7 +391,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
- url_cache: Optional[str],
+ url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
@@ -398,7 +404,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: The URL cache value.
+ url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
@@ -495,12 +501,13 @@ class ThumbnailResource(DirectServeJsonResource):
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- thumbnail_length=thumbnail_info["thumbnail_length"],
+ thumbnail=ThumbnailInfo(
+ width=thumbnail_info["thumbnail_width"],
+ height=thumbnail_info["thumbnail_height"],
+ type=thumbnail_info["thumbnail_type"],
+ method=thumbnail_info["thumbnail_method"],
+ length=thumbnail_info["thumbnail_length"],
+ ),
)
# No matching thumbnail was found.
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a65e9e1802..df54a40649 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
- def set_limits(max_image_pixels: int):
+ def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index 67c1ed1f5f..1c1c7b3613 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request
@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: Request):
+ async def _async_render_POST(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 36ba401656..81fec39659 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -13,16 +13,20 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class OIDCResource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index 7785f17e90..4f375cb74c 100644
--- a/synapse/rest/synapse/client/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_oidc_callback(request)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# the auth response can be returned via an x-www-form-urlencoded form instead
# of GET params, as per
# https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d30b478b98..28ae083497 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Generator, List, Tuple
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -27,6 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util.templates import build_jinja_env
if TYPE_CHECKING:
@@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request: Request):
+ async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
localpart = parse_string(request, "username", required=True)
session_id = get_username_mapping_session_cookie_from_request(request)
@@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- def template_search_dirs():
+ def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: SynapseRequest):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None
diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 781ccb237c..3f247e6a2c 100644
--- a/synapse/rest/synapse/client/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -13,17 +13,21 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class SAML2Resource(Resource):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index b37c7083dc..64378ed57b 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING
import saml2.metadata
from twisted.web.resource import Resource
+from twisted.web.server import Request
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SAML2MetadataResource(Resource):
@@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string(
configfile=None, config=self.sp_config
)
diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index 774ccd870f..47d2a6a229 100644
--- a/synapse/rest/synapse/client/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
@@ -15,7 +15,10 @@
from typing import TYPE_CHECKING
+from twisted.web.server import Request
+
from synapse.http.server import DirectServeHtmlResource
+from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
@@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 6a66a88c53..c80a3a99aa 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,26 +13,26 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Optional
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse.http.server import set_cors_headers
+from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class WellKnownBuilder:
- """Utility to construct the well-known response
-
- Args:
- hs (synapse.server.HomeServer):
- """
-
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._config = hs.config
- def get_well_known(self):
+ def get_well_known(self) -> Optional[JsonDict]:
# if we don't have a public_baseurl, we can't help much here.
if self._config.server.public_baseurl is None:
return None
@@ -52,11 +52,11 @@ class WellKnownResource(Resource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs)
- def render_GET(self, request):
+ def render_GET(self, request: Request) -> bytes:
set_cors_headers(request)
r = self._well_known_builder.get_well_known()
if not r:
diff --git a/synapse/server.py b/synapse/server.py
index 4777ef585d..637eb15b78 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -392,7 +392,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
- if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+ if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
return InsecureInterceptableContextFactory()
return RegularPolicyForHTTPS()
@@ -418,8 +418,8 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
return SimpleHttpClient(
self,
- ip_whitelist=self.config.ip_range_whitelist,
- ip_blacklist=self.config.ip_range_blacklist,
+ ip_whitelist=self.config.server.ip_range_whitelist,
+ ip_blacklist=self.config.server.ip_range_blacklist,
use_proxy=True,
)
@@ -801,18 +801,18 @@ class HomeServer(metaclass=abc.ABCMeta):
logger.info(
"Connecting to redis (host=%r port=%r) for external cache",
- self.config.redis_host,
- self.config.redis_port,
+ self.config.redis.redis_host,
+ self.config.redis.redis_port,
)
return lazyConnection(
hs=self,
- host=self.config.redis_host,
- port=self.config.redis_port,
+ host=self.config.redis.redis_host,
+ port=self.config.redis.redis_port,
password=self.config.redis.redis_password,
reconnect=True,
)
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
- return self.config.send_federation
+ return self.config.worker.send_federation
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0084d9f96c..f5a8f90a0f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -1632,7 +1632,7 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
column: str,
- iterable: Iterable[Any],
+ iterable: Collection[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
@@ -1891,29 +1891,32 @@ class DatabasePool:
txn: LoggingTransaction,
table: str,
column: str,
- iterable: Iterable[Any],
+ values: Collection[Any],
keyvalues: Dict[str, Any],
) -> int:
"""Executes a DELETE query on the named table.
- Filters rows by if value of `column` is in `iterable`.
+ Deletes the rows:
+ - whose value of `column` is in `values`; AND
+ - that match extra column-value pairs specified in `keyvalues`.
Args:
txn: Transaction object
table: string giving the table name
- column: column name to test for inclusion against `iterable`
- iterable: list
- keyvalues: dict of column names and values to select the rows with
+ column: column name to test for inclusion against `values`
+ values: values of `column` which choose rows to delete
+ keyvalues: dict of extra column names and values to select the rows
+ with. They will be ANDed together with the main predicate.
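+
+ For example, with table="pushers", column="user_name", values=users and
+ keyvalues={}, this runs, roughly:
+ DELETE FROM pushers WHERE user_name IN (?, ?, ...)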
Returns:
Number rows deleted
"""
- if not iterable:
+ if not values:
return 0
sql = "DELETE FROM %s" % table
- clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
clauses = [clause]
for key, value in keyvalues.items():
@@ -2098,7 +2101,7 @@ class DatabasePool:
def make_in_list_sql_clause(
- database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
+ database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1dc347f0c9..5c21402dea 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -61,6 +61,7 @@ from .registration import RegistrationStore
from .rejections import RejectionsStore
from .relations import RelationsStore
from .room import RoomStore
+from .room_batch import RoomBatchStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .session import SessionStore
@@ -81,6 +82,7 @@ class DataStore(
EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
+ RoomBatchStore,
RegistrationStore,
StreamStore,
ProfileStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 1d02795f43..d0cf3460da 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -494,7 +494,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn,
table="ignored_users",
column="ignored_user_id",
- iterable=previously_ignored_users - currently_ignored_users,
+ values=previously_ignored_users - currently_ignored_users,
keyvalues={"ignorer_user_id": user_id},
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 047782eb06..10184d6ae7 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1034,13 +1034,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
LIMIT ?
"""
- # Find any chunk connections of a given insertion event
- chunk_connection_query = """
+ # Find any batch connections of a given insertion event
+ batch_connection_query = """
SELECT e.depth, c.event_id FROM insertion_events AS i
- /* Find the chunk that connects to the given insertion event */
- INNER JOIN chunk_events AS c
- ON i.next_chunk_id = c.chunk_id
- /* Get the depth of the chunk start event from the events table */
+ /* Find the batch that connects to the given insertion event */
+ INNER JOIN batch_events AS c
+ ON i.next_batch_id = c.batch_id
+ /* Get the depth of the batch start event from the events table */
INNER JOIN events AS e USING (event_id)
/* Find an insertion event which matches the given event_id */
WHERE i.event_id = ?
@@ -1077,12 +1077,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.add(event_id)
- # Try and find any potential historical chunks of message history.
+ # Try and find any potential historical batches of message history.
#
# First we look for an insertion event connected to the current
# event (by prev_event). If we find any, we need to go and try to
- # find any chunk events connected to the insertion event (by
- # chunk_id). If we find any, we'll add them to the queue and
+ # find any batch events connected to the insertion event (by
+ # batch_id). If we find any, we'll add them to the queue and
# navigate up the DAG like normal in the next iteration of the loop.
txn.execute(
connected_insertion_event_query, (event_id, limit - len(event_results))
@@ -1097,17 +1097,17 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
connected_insertion_event = row[1]
queue.put((-connected_insertion_event_depth, connected_insertion_event))
- # Find any chunk connections for the given insertion event
+ # Find any batch connections for the given insertion event
txn.execute(
- chunk_connection_query,
+ batch_connection_query,
(connected_insertion_event, limit - len(event_results)),
)
- chunk_start_event_id_results = txn.fetchall()
+ batch_start_event_id_results = txn.fetchall()
logger.debug(
- "_get_backfill_events: chunk_start_event_id_results %s",
- chunk_start_event_id_results,
+ "_get_backfill_events: batch_start_event_id_results %s",
+ batch_start_event_id_results,
)
- for row in chunk_start_event_id_results:
+ for row in batch_start_event_id_results:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8e691678e5..584f818ff3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -667,7 +667,7 @@ class PersistEventsStore:
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
- iterable=new_chain_tuples,
+ values=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by
@@ -1509,7 +1509,7 @@ class PersistEventsStore:
self._handle_event_relations(txn, event)
self._handle_insertion_event(txn, event)
- self._handle_chunk_event(txn, event)
+ self._handle_batch_event(txn, event)
# Store the labels for this event.
labels = event.content.get(EventContentFields.LABELS)
@@ -1790,23 +1790,23 @@ class PersistEventsStore:
):
return
- next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
- if next_chunk_id is None:
- # Invalid insertion event without next chunk ID
+ next_batch_id = event.content.get(EventContentFields.MSC2716_NEXT_BATCH_ID)
+ if next_batch_id is None:
+ # Invalid insertion event without next batch ID
return
logger.debug(
- "_handle_insertion_event (next_chunk_id=%s) %s", next_chunk_id, event
+ "_handle_insertion_event (next_batch_id=%s) %s", next_batch_id, event
)
- # Keep track of the insertion event and the chunk ID
+ # Keep track of the insertion event and the batch ID
self.db_pool.simple_insert_txn(
txn,
table="insertion_events",
values={
"event_id": event.event_id,
"room_id": event.room_id,
- "next_chunk_id": next_chunk_id,
+ "next_batch_id": next_batch_id,
},
)
@@ -1822,8 +1822,8 @@ class PersistEventsStore:
},
)
- def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
- """Handles inserting the chunk edges/connections between the chunk event
+ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+ """Handles inserting the batch edges/connections between the batch event
and an insertion event. Part of MSC2716.
Args:
@@ -1831,11 +1831,11 @@ class PersistEventsStore:
event: The event to process
"""
- if event.type != EventTypes.MSC2716_CHUNK:
- # Not a chunk event
+ if event.type != EventTypes.MSC2716_BATCH:
+ # Not a batch event
return
- # Skip processing a chunk event if the room version doesn't
+ # Skip processing a batch event if the room version doesn't
# support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id)
room_creator = self.db_pool.simple_select_one_onecol_txn(
@@ -1852,35 +1852,35 @@ class PersistEventsStore:
):
return
- chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
- if chunk_id is None:
- # Invalid chunk event without a chunk ID
+ batch_id = event.content.get(EventContentFields.MSC2716_BATCH_ID)
+ if batch_id is None:
+ # Invalid batch event without a batch ID
return
- logger.debug("_handle_chunk_event chunk_id=%s %s", chunk_id, event)
+ logger.debug("_handle_batch_event batch_id=%s %s", batch_id, event)
- # Keep track of the insertion event and the chunk ID
+ # Keep track of the insertion event and the batch ID
self.db_pool.simple_insert_txn(
txn,
- table="chunk_events",
+ table="batch_events",
values={
"event_id": event.event_id,
"room_id": event.room_id,
- "chunk_id": chunk_id,
+ "batch_id": batch_id,
},
)
- # When we receive an event with a `chunk_id` referencing the
- # `next_chunk_id` of the insertion event, we can remove it from the
+ # When we receive an event with a `batch_id` referencing the
+ # `next_batch_id` of the insertion event, we can remove it from the
# `insertion_event_extremities` table.
sql = """
DELETE FROM insertion_event_extremities WHERE event_id IN (
SELECT event_id FROM insertion_events
- WHERE next_chunk_id = ?
+ WHERE next_batch_id = ?
)
"""
- txn.execute(sql, (chunk_id,))
+ txn.execute(sql, (batch_id,))
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 6fcb2b8353..1afc59fafb 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -490,7 +490,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_forward_extremities",
column="event_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
@@ -520,7 +520,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="_extremities_to_check",
column="event_id",
- iterable=original_set,
+ values=original_set,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 63ac09c61d..a93caae8d0 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -324,7 +324,7 @@ class PusherWorkerStore(SQLBaseStore):
txn,
table="pushers",
column="user_name",
- iterable=users,
+ values=users,
keyvalues={},
)
@@ -373,7 +373,7 @@ class PusherWorkerStore(SQLBaseStore):
txn,
table="pushers",
column="id",
- iterable=(pusher_id for pusher_id, token in pushers if token is None),
+ values=[pusher_id for pusher_id, token in pushers if token is None],
keyvalues={},
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index edeaacd7a6..01a4281301 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer
@@ -153,12 +153,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
async def get_linearized_receipts_for_rooms(
- self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+ self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients.
Args:
- room_id: List of room_ids.
+ room_ids: The room IDs to fetch receipts of.
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py
new file mode 100644
index 0000000000..a383388757
--- /dev/null
+++ b/synapse/storage/databases/main/room_batch.py
@@ -0,0 +1,36 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from synapse.storage._base import SQLBaseStore
+
+
+class RoomBatchStore(SQLBaseStore):
+ async def get_insertion_event_by_batch_id(self, batch_id: str) -> Optional[str]:
+ """Retrieve a insertion event ID.
+
+ Args:
+ batch_id: The batch ID of the insertion event to retrieve.
+
+ Returns:
+ The event_id of an insertion event, or None if there is no known
+ insertion event for the given batch ID.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ table="insertion_events",
+ keyvalues={"next_batch_id": batch_id},
+ retcol="event_id",
+ allow_none=True,
+ )
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 8e22da99ae..a8e8dd4577 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -473,7 +473,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
@@ -481,7 +481,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
- iterable=to_delete,
+ values=to_delete,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 4d6bbc94c7..340ca9e47d 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -326,7 +326,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions_ips",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
@@ -377,7 +377,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions_credentials",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
@@ -386,7 +386,7 @@ class UIAuthWorkerStore(SQLBaseStore):
txn,
table="ui_auth_sessions",
column="session_id",
- iterable=session_ids,
+ values=session_ids,
keyvalues={},
)
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 8aebdc2817..718f3e9976 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -85,19 +85,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
- # If search all users is on, get all the users we want to add.
- if self.hs.config.user_directory_search_all_users:
- sql = (
- "CREATE TABLE IF NOT EXISTS "
- + TEMP_TABLE
- + "_users(user_id TEXT NOT NULL)"
- )
- txn.execute(sql)
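+ # The temporary users table is now created unconditionally: every local
+ # user is added to the directory, not just when search_all_users is set.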
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_users(user_id TEXT NOT NULL)"
+ )
+ txn.execute(sql)
- txn.execute("SELECT name FROM users")
- users = [{"user_id": x[0]} for x in txn.fetchall()]
+ txn.execute("SELECT name FROM users")
+ users = [{"user_id": x[0]} for x in txn.fetchall()]
- self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = await self.get_max_stream_id_in_current_state_deltas()
await self.db_pool.runInteraction(
@@ -265,13 +263,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def _populate_user_directory_process_users(self, progress, batch_size):
"""
- If search_all_users is enabled, add all of the users to the user directory.
+ Add all local users to the user directory.
"""
- if not self.hs.config.user_directory_search_all_users:
- await self.db_pool.updates._end_background_update(
- "populate_user_directory_process_users"
- )
- return 1
def _get_next_batch(txn):
sql = "SELECT user_id FROM %s LIMIT %s" % (
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index c2891cb07f..eb1118d2cb 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""
- def _count_state_group_hops_txn(self, txn, state_group):
+ def _count_state_group_hops_txn(
+ self, txn: LoggingTransaction, state_group: int
+ ) -> int:
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
+ next_group: Optional[int] = state_group
count = 0
while next_group:
@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter: Optional[StateFilter] = None
- ):
+ self,
+ txn: LoggingTransaction,
+ groups: List[int],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Mapping[int, StateMap[str]]:
state_filter = state_filter or StateFilter.all()
- results = {group: {} for group in groups}
+ results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""
for group in groups:
- args = [group]
+ args: List[Union[int, str]] = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- next_group = group
+ next_group: Optional[int] = group
while next_group:
# We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
allow_none=True,
)
+ # The results shouldn't be considered mutable.
return results
@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
- async def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(
+ self, progress: dict, batch_size: int
+ ) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
max_group = rows[0][0]
- def reindex_txn(txn):
+ def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
- (prev_group,) = txn.fetchone()
+ # There will be a result due to the coalesce.
+ (prev_group,) = txn.fetchone() # type: ignore
new_last_state_group = state_group
if prev_group:
@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
# otherwise read performance degrades.
continue
- prev_state = self._get_state_groups_from_groups_txn(
+ prev_state_by_group = self._get_state_groups_from_groups_txn(
txn, [prev_group]
)
- prev_state = prev_state[prev_group]
+ prev_state = prev_state_by_group[prev_group]
- curr_state = self._get_state_groups_from_groups_txn(
+ curr_state_by_group = self._get_state_groups_from_groups_txn(
txn, [state_group]
)
- curr_state = curr_state[state_group]
+ curr_state = curr_state_by_group[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current has a strict super set
@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return result * BATCH_SIZE_SCALE_FACTOR
- async def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
+ async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+ def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..c4c8c0021b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,43 +13,56 @@
# limitations under the License.
import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
"""Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
+ us use the iterable flag when caching
"""
- __slots__ = []
+ prev_group: Optional[int]
+ delta_ids: Optional[StateMap[str]]
- def __len__(self):
+ def __len__(self) -> int:
return len(self.delta_ids) if self.delta_ids else 0
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups."""
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# We size the non-members cache to be smaller than the members cache as the
# vast majority of state in Matrix (today) is member events.
- self._state_group_cache = DictionaryCache(
+ self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000,
)
- self._state_group_members_cache = DictionaryCache(
+ self._state_group_members_cache: DictionaryCache[
+ int, StateKey, str
+ ] = DictionaryCache(
"*stateGroupMembersCache*",
500000,
)
- def get_max_state_group_txn(txn: Cursor):
+ def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
- return txn.fetchone()[0]
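+ # The COALESCE aggregate always produces a row, so fetchone() is
+ # never None here (hence the type: ignore).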
+ return txn.fetchone()[0] # type: ignore
self._state_group_seq_gen = build_sequence_generator(
db_conn,
@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
- async def get_state_group_delta(self, state_group):
+ async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- (prev_group, delta_ids), where both may be None.
+ _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
"""
- def _get_state_group_delta_txn(txn):
+ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
prev_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
Dict of state group to state map.
"""
- results = {}
+ results: Dict[int, StateMap[str]] = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return results
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ def _get_state_for_group_using_cache(
+ self,
+ cache: DictionaryCache[int, StateKey, str],
+ group: int,
+ state_filter: StateFilter,
+ ) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ cache: the state group cache to use
+ group: The state group to lookup
+ state_filter: The state filter used to fetch state from the database.
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
+ Returns:
+ 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
+ requested state from the cache; if False, we need to query the DB
+ for the missing state.
"""
cache_entry = cache.get(group)
state_dict_ids = cache_entry.value
@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
def _get_state_for_groups_using_cache(
- self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
- ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+ self,
+ groups: Iterable[int],
+ cache: DictionaryCache[int, StateKey, str],
+ state_filter: StateFilter,
+ ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
def _insert_into_cache(
self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
+ group_to_state_dict: Dict[int, StateMap[str]],
+ state_filter: StateFilter,
+ cache_seq_num_members: int,
+ cache_seq_num_non_members: int,
+ ) -> None:
"""Inserts results from querying the database into the relevant cache.
Args:
- group_to_state_dict (dict): The new entries pulled from database.
+ group_to_state_dict: The new entries pulled from the database.
Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
- cache_seq_num_members (int): Sequence number of member cache since
+ cache_seq_num_members: Sequence number of member cache since
last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
+ cache_seq_num_non_members: Sequence number of non-member cache since
last lookup in cache
"""
@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group ID
"""
- def _store_state_group_txn(txn):
+ def _store_state_group_txn(txn: LoggingTransaction) -> int:
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
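+ # prev_group is only set alongside delta_ids; assert to narrow the
+ # Optional type for mypy.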
+ assert delta_ids is not None
+
self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete
+ self, room_id: str, state_groups_to_delete: Collection[int]
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Args:
room_id: The room the state groups belong to (must all be in the
same room).
- state_groups_to_delete (Collection[int]): Set of all state groups
- to delete.
+ state_groups_to_delete: Set of all state groups to delete.
"""
await self.db_pool.runInteraction(
@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ def _purge_unreferenced_state_groups(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
logger.info(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# groups to non delta versions.
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
+ curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state_by_group[sg]
self.db_pool.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
- async def purge_room_state(self, room_id, state_groups_to_delete):
+ async def purge_room_state(
+ self, room_id: str, state_groups_to_delete: Collection[int]
+ ) -> None:
"""Deletes all record of a room from state tables
Args:
- room_id (str):
- state_groups_to_delete (list[int]): State groups to delete
+ room_id: The room whose state is being deleted.
+ state_groups_to_delete: State groups to delete
"""
await self.db_pool.runInteraction(
@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete,
)
- def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ def _purge_room_state_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_groups_to_delete: Collection[int],
+ ) -> None:
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
@@ -628,7 +664,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_groups_state",
column="state_group",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
@@ -639,7 +675,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_group_edges",
column="state_group",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
@@ -650,6 +686,6 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
txn,
table="state_groups",
column="id",
- iterable=state_groups_to_delete,
+ values=state_groups_to_delete,
keyvalues={},
)
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index c552dbf04c..10a46b5e82 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -73,7 +73,7 @@ class RelationPaginationToken:
t, s = string.split("-")
return RelationPaginationToken(int(t), int(s))
except ValueError:
- raise SynapseError(400, "Invalid token")
+ raise SynapseError(400, "Invalid relation pagination token")
def to_string(self) -> str:
return "%d-%d" % (self.topological, self.stream)
@@ -103,7 +103,7 @@ class AggregationPaginationToken:
c, s = string.split("-")
return AggregationPaginationToken(int(c), int(s))
except ValueError:
- raise SynapseError(400, "Invalid token")
+ raise SynapseError(400, "Invalid aggregation pagination token")
def to_string(self) -> str:
return "%d-%d" % (self.count, self.stream)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index af9cc69949..aa2ce44c6c 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -14,7 +14,7 @@
# When updating these values, please leave a short summary of the changes below.
-SCHEMA_VERSION = 63
+SCHEMA_VERSION = 64
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
diff --git a/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres
new file mode 100644
index 0000000000..5f38993208
--- /dev/null
+++ b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
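+-- Rename the MSC2716 "chunk" naming to "batch": the next_chunk_id column on
+-- insertion_events, and the chunk_events table with its chunk_id column.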
+ALTER TABLE insertion_events RENAME COLUMN next_chunk_id TO next_batch_id;
+DROP INDEX insertion_events_next_chunk_id;
+CREATE INDEX IF NOT EXISTS insertion_events_next_batch_id ON insertion_events(next_batch_id);
+
+ALTER TABLE chunk_events RENAME TO batch_events;
+ALTER TABLE batch_events RENAME COLUMN chunk_id TO batch_id;
+DROP INDEX chunk_events_chunk_id;
+CREATE INDEX IF NOT EXISTS batch_events_batch_id ON batch_events(batch_id);
diff --git a/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite
new file mode 100644
index 0000000000..4989563995
--- /dev/null
+++ b/synapse/storage/schema/main/delta/64/01msc2716_chunk_to_batch_rename.sql.sqlite
@@ -0,0 +1,37 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Re-create the insertion_events table since the SQLite versions we support
+-- can't rename columns (next_chunk_id -> next_batch_id)
+DROP TABLE insertion_events;
+CREATE TABLE IF NOT EXISTS insertion_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ next_batch_id TEXT NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_events_event_id ON insertion_events(event_id);
+CREATE INDEX IF NOT EXISTS insertion_events_next_batch_id ON insertion_events(next_batch_id);
+
+-- Re-create the chunk_events table as batch_events since the SQLite versions
+-- we support can't rename columns (chunk_id -> batch_id)
+DROP TABLE chunk_events;
+CREATE TABLE IF NOT EXISTS batch_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ batch_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS batch_events_event_id ON batch_events(event_id);
+CREATE INDEX IF NOT EXISTS batch_events_batch_id ON batch_events(batch_id);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e5400d681a..5e86befde4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -25,12 +25,15 @@ from typing import (
)
import attr
+from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
+ from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
+
from synapse.server import HomeServer
from synapse.storage.databases import Databases
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.
@@ -53,14 +56,19 @@ class StateFilter:
appear in `types`.
"""
- types = attr.ib(type=Dict[str, Optional[Set[str]]])
+ types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
- self.types = {k: v for k, v in self.types.items() if v is not None}
+ # this is needed to work around the fact that StateFilter is frozen
+ object.__setattr__(
+ self,
+ "types",
+ frozendict({k: v for k, v in self.types.items() if v is not None}),
+ )
@staticmethod
def all() -> "StateFilter":
@@ -69,7 +77,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=True)
+ return StateFilter(types=frozendict(), include_others=True)
@staticmethod
def none() -> "StateFilter":
@@ -78,7 +86,7 @@ class StateFilter:
Returns:
The new state filter.
"""
- return StateFilter(types={}, include_others=False)
+ return StateFilter(types=frozendict(), include_others=False)
@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +111,12 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore
- return StateFilter(types=type_dict)
+ return StateFilter(
+ types=frozendict(
+ (k, frozenset(v) if v is not None else None)
+ for k, v in type_dict.items()
+ )
+ )
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +129,10 @@ class StateFilter:
Returns:
The new state filter
"""
- return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+ return StateFilter(
+ types=frozendict({EventTypes.Member: frozenset(members)}),
+ include_others=True,
+ )
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +189,7 @@ class StateFilter:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
- types={EventTypes.Member: self.types[EventTypes.Member]},
+ types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
@@ -245,14 +261,15 @@ class StateFilter:
return len(self.concrete_types())
- def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
- """Returns the state filtered with by this StateFilter
+ def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+ """Returns the state filtered with by this StateFilter.
Args:
state: The state map to filter
Returns:
- The filtered state map
+ The filtered state map.
+ This is a copy, so it's safe to mutate.
"""
if self.is_full():
return dict(state_dict)
@@ -324,14 +341,16 @@ class StateFilter:
if state_keys is None:
member_filter = StateFilter.all()
else:
- member_filter = StateFilter({EventTypes.Member: state_keys})
+ member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()
non_member_filter = StateFilter(
- types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+ types=frozendict(
+ {k: v for k, v in self.types.items() if k != EventTypes.Member}
+ ),
include_others=self.include_others,
)
@@ -358,7 +377,8 @@ class StateGroupStorage:
make up the delta between the old and new state groups.
"""
- return await self.stores.state.get_state_group_delta(state_group)
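+ # The store now returns an attrs object rather than a tuple; unpack it
+ # so this method keeps its (prev_group, delta_ids) return shape.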
+ state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+ return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 5e83dba2ed..806b671305 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -11,3 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import Collection, Generic, List, Optional, Tuple, TypeVar
+
+from synapse.types import UserID
+
+# The key; this is either a stream token or an int.
+K = TypeVar("K")
+# The return type.
+R = TypeVar("R")
+
+
+class EventSource(Generic[K, R]):
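+ """Interface for a source of new events; K is the stream key type (a
+ stream token or int) and R is the type of event returned.
+ """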
+ async def get_new_events(
+ self,
+ user: UserID,
+ from_key: K,
+ limit: Optional[int],
+ room_ids: Collection[str],
+ is_guest: bool,
+ explicit_room_id: Optional[str] = None,
+ ) -> Tuple[List[R], K]:
+ ...
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 99b0aac2fb..21591d0bfd 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,29 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Iterator, Tuple
+
+import attr
from synapse.handlers.account_data import AccountDataEventSource
from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.streams import EventSource
from synapse.types import StreamToken
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
-class EventSources:
- SOURCE_TYPES = {
- "room": RoomEventSource,
- "presence": PresenceEventSource,
- "typing": TypingNotificationEventSource,
- "receipt": ReceiptEventSource,
- "account_data": AccountDataEventSource,
- }
- def __init__(self, hs):
- self.sources: Dict[str, Any] = {
- name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
- }
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _EventSourcesInner:
+ room: RoomEventSource
+ presence: PresenceEventSource
+ typing: TypingNotificationEventSource
+ receipt: ReceiptEventSource
+ account_data: AccountDataEventSource
+
+ def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
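+ """Yield (name, source) pairs for each event source, in declaration
+ order."""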
+ for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined]
+ yield attribute.name, getattr(self, attribute.name)
+
+
+class EventSources:
+ def __init__(self, hs: "HomeServer"):
+ self.sources = _EventSourcesInner(
+ *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined]
+ )
self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken:
@@ -44,11 +55,11 @@ class EventSources:
groups_key = self.store.get_group_stream_token()
token = StreamToken(
- room_key=self.sources["room"].get_current_key(),
- presence_key=self.sources["presence"].get_current_key(),
- typing_key=self.sources["typing"].get_current_key(),
- receipt_key=self.sources["receipt"].get_current_key(),
- account_data_key=self.sources["account_data"].get_current_key(),
+ room_key=self.sources.room.get_current_key(),
+ presence_key=self.sources.presence.get_current_key(),
+ typing_key=self.sources.typing.get_current_key(),
+ receipt_key=self.sources.receipt.get_current_key(),
+ account_data_key=self.sources.account_data.get_current_key(),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
@@ -67,7 +78,7 @@ class EventSources:
The current token for pagination.
"""
token = StreamToken(
- room_key=self.sources["room"].get_current_key(),
+ room_key=self.sources.room.get_current_key(),
presence_key=0,
typing_key=0,
receipt_key=0,
diff --git a/synapse/types.py b/synapse/types.py
index d4759b2dfd..90168ce8fa 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -511,7 +511,7 @@ class RoomStreamToken:
)
except Exception:
pass
- raise SynapseError(400, "Invalid token %r" % (string,))
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
@@ -520,7 +520,7 @@ class RoomStreamToken:
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
- raise SynapseError(400, "Invalid token %r" % (string,))
+ raise SynapseError(400, "Invalid room stream token %r" % (string,))
def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
"""Return a new token such that if an event is after both this token and
@@ -619,7 +619,7 @@ class StreamToken:
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
)
except Exception:
- raise SynapseError(400, "Invalid Token")
+ raise SynapseError(400, "Invalid stream token")
async def to_string(self, store: "DataStore") -> str:
return self._SEPARATOR.join(
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index cab1bf0c15..df4d61e4b6 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections
import logging
+import typing
+from enum import Enum, auto
from sys import intern
from typing import Callable, Dict, Optional, Sized
@@ -34,7 +36,7 @@ collectors_by_name: Dict[str, "CacheMetric"] = {}
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
-cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name", "reason"])
cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"])
cache_memory_usage = Gauge(
@@ -46,11 +48,16 @@ cache_memory_usage = Gauge(
response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
response_cache_evicted = Gauge(
- "synapse_util_caches_response_cache:evicted_size", "", ["name"]
+ "synapse_util_caches_response_cache:evicted_size", "", ["name", "reason"]
)
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+class EvictionReason(Enum):
+ size = auto()
+ time = auto()
+
+
@attr.s(slots=True)
class CacheMetric:
@@ -61,7 +68,9 @@ class CacheMetric:
hits = attr.ib(default=0)
misses = attr.ib(default=0)
- evicted_size = attr.ib(default=0)
+ eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
+ factory=collections.Counter
+ )
memory_usage = attr.ib(default=None)
def inc_hits(self) -> None:
@@ -70,8 +79,8 @@ class CacheMetric:
def inc_misses(self) -> None:
self.misses += 1
- def inc_evictions(self, size: int = 1) -> None:
- self.evicted_size += size
+ def inc_evictions(self, reason: EvictionReason, size: int = 1) -> None:
+ self.eviction_size_by_reason[reason] += size
def inc_memory_usage(self, memory: int) -> None:
if self.memory_usage is None:
@@ -94,14 +103,20 @@ class CacheMetric:
if self._cache_type == "response_cache":
response_cache_size.labels(self._cache_name).set(len(self._cache))
response_cache_hits.labels(self._cache_name).set(self.hits)
- response_cache_evicted.labels(self._cache_name).set(self.evicted_size)
+ for reason in EvictionReason:
+ response_cache_evicted.labels(self._cache_name, reason.name).set(
+ self.eviction_size_by_reason[reason]
+ )
response_cache_total.labels(self._cache_name).set(
self.hits + self.misses
)
else:
cache_size.labels(self._cache_name).set(len(self._cache))
cache_hits.labels(self._cache_name).set(self.hits)
- cache_evicted.labels(self._cache_name).set(self.evicted_size)
+ for reason in EvictionReason:
+ cache_evicted.labels(self._cache_name, reason.name).set(
+ self.eviction_size_by_reason[reason]
+ )
cache_total.labels(self._cache_name).set(self.hits + self.misses)
if getattr(self._cache, "max_size", None):
cache_max_size.labels(self._cache_name).set(self._cache.max_size)
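The metrics change above replaces the single eviction counter with one counter per eviction reason, keyed by an enum, and exports one labelled sample per reason. A minimal sketch of that bookkeeping, independent of Prometheus (assuming only the two reasons added here):

```python
import collections
import typing
from enum import Enum, auto

class EvictionReason(Enum):
    size = auto()
    time = auto()

class Metric:
    def __init__(self) -> None:
        # Counter defaults missing keys to 0, so reasons that never fire
        # still export a zero sample below.
        self.eviction_size_by_reason: typing.Counter[EvictionReason] = (
            collections.Counter()
        )

    def inc_evictions(self, reason: EvictionReason, size: int = 1) -> None:
        self.eviction_size_by_reason[reason] += size

metric = Metric()
metric.inc_evictions(EvictionReason.size, 3)
metric.inc_evictions(EvictionReason.time)
for reason in EvictionReason:
    # With a real Gauge this would be .labels(name, reason.name).set(...).
    print(reason.name, metric.eviction_size_by_reason[reason])
# size 3
# time 1
```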
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index ade088aae2..485ddb1893 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -130,7 +130,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
sequence: int,
key: KT,
value: Dict[DKT, DV],
- fetched_keys: Optional[Set[DKT]] = None,
+ fetched_keys: Optional[Iterable[DKT]] = None,
) -> None:
"""Updates the entry in the cache
@@ -155,7 +155,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
self._update_or_insert(key, value, fetched_keys)
def _update_or_insert(
- self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
+ self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index bde16b8577..c3f72aa06d 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -22,7 +22,7 @@ from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
-from synapse.util.caches import register_cache
+from synapse.util.caches import EvictionReason, register_cache
logger = logging.getLogger(__name__)
@@ -98,9 +98,9 @@ class ExpiringCache(Generic[KT, VT]):
while self._max_size and len(self) > self._max_size:
_key, value = self._cache.popitem(last=False)
if self.iterable:
- self.metrics.inc_evictions(len(value.value))
+ self.metrics.inc_evictions(EvictionReason.size, len(value.value))
else:
- self.metrics.inc_evictions()
+ self.metrics.inc_evictions(EvictionReason.size)
def __getitem__(self, key: KT) -> VT:
try:
@@ -175,9 +175,9 @@ class ExpiringCache(Generic[KT, VT]):
for k in keys_to_delete:
value = self._cache.pop(k)
if self.iterable:
- self.metrics.inc_evictions(len(value.value))
+ self.metrics.inc_evictions(EvictionReason.time, len(value.value))
else:
- self.metrics.inc_evictions()
+ self.metrics.inc_evictions(EvictionReason.time)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 17cb98ff0b..4ff62b403f 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -40,7 +40,7 @@ from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util import Clock, caches
-from synapse.util.caches import CacheMetric, register_cache
+from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.linked_list import ListNode
@@ -405,7 +405,7 @@ class LruCache(Generic[KT, VT]):
evicted_len = delete_node(node)
cache.pop(node.key, None)
if metrics:
- metrics.inc_evictions(evicted_len)
+ metrics.inc_evictions(EvictionReason.size, evicted_len)
def synchronized(f: FT) -> FT:
@wraps(f)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 2928c4f48c..57cc3e2646 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
import synapse.types
from synapse.api.errors import AuthError, SynapseError
+from synapse.rest import admin
from synapse.types import UserID
from tests import unittest
@@ -25,6 +26,8 @@ from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase):
"""Tests profile management."""
+ servlets = [admin.register_servlets]
+
def make_homeserver(self, reactor, clock):
self.mock_federation = Mock()
self.mock_registry = Mock()
@@ -46,11 +49,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.frank = UserID.from_string("@1234ABCD:test")
+ self.frank = UserID.from_string("@1234abcd:test")
self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote")
- self.get_success(self.store.create_profile(self.frank.localpart))
+ self.get_success(self.register_user(self.frank.localpart, "frankpassword"))
self.handler = hs.get_profile_handler()
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 732a12c9bd..5de89c873b 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -23,7 +23,7 @@ from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
- self.event_source = hs.get_event_sources().sources["receipt"]
+ self.event_source = hs.get_event_sources().sources.receipt
# In the first param of _test_filters_hidden we use "hidden" instead of
# ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index fa3cff598e..000f9b9fde 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -89,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_typing_handler()
- self.event_source = hs.get_event_sources().sources["typing"]
+ self.event_source = hs.get_event_sources().sources.typing
self.datastore = hs.get_datastore()
self.datastore.get_destination_retry_timings = Mock(
@@ -171,7 +171,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ )
)
self.assertEquals(
events[0],
@@ -239,7 +241,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ )
)
self.assertEquals(
events[0],
@@ -276,7 +280,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=0,
+ limit=None,
+ room_ids=[OTHER_ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(events[0], [])
self.assertEquals(events[1], 0)
@@ -324,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+ )
)
self.assertEquals(
events[0],
@@ -350,7 +362,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=0,
+ limit=None,
+ room_ids=[ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
@@ -369,7 +387,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 2)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=1,
+ limit=None,
+ room_ids=[ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
@@ -392,7 +416,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 3)
events = self.get_success(
- self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ self.event_source.get_new_events(
+ user=U_APPLE,
+ from_key=0,
+ limit=None,
+ room_ids=[ROOM_ID],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
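Every hunk in this test file makes the same mechanical change: the typing event source's `get_new_events` is now called with its full set of explicit keyword arguments rather than just `room_ids` and `from_key`. A hypothetical stub mirroring the keyword signature the tests exercise (the real method lives on Synapse's typing event source; the return shape matches the `events[0]`/`events[1]` assertions above):

```python
from typing import Iterable, List, Optional, Tuple

async def get_new_events(
    user: "UserID",          # the user the events are for
    from_key: int,           # stream position to read from
    limit: Optional[int],    # None means no cap on returned events
    room_ids: Iterable[str],
    is_guest: bool,
) -> Tuple[List[dict], int]:
    # Returns (events, new_from_key).
    return [], from_key
```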
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index ac419f0db3..01b1b0d4a0 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -1,4 +1,4 @@
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
# limitations under the License.
import logging
import os
-from binascii import unhexlify
from typing import Optional, Tuple
from twisted.internet.protocol import Factory
@@ -28,6 +27,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
+from tests.test_utils import SMALL_PNG
logger = logging.getLogger(__name__)
@@ -190,31 +190,25 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
- png_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
- request1.write(png_data)
+ request1.write(SMALL_PNG)
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
- self.assertEqual(channel1.result["body"], png_data)
+ self.assertEqual(channel1.result["body"], SMALL_PNG)
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"])
- request2.write(png_data)
+ request2.write(SMALL_PNG)
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
- self.assertEqual(channel2.result["body"], png_data)
+ self.assertEqual(channel2.result["body"], SMALL_PNG)
# We expect only three new thumbnails to have been persisted.
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index bfa638fb4b..febd40b656 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
import json
import os
import urllib.parse
-from binascii import unhexlify
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -28,6 +27,7 @@ from synapse.rest.client import groups, login, room
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
class VersionTestCase(unittest.HomeserverTestCase):
@@ -150,11 +150,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
- self.image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
def make_homeserver(self, reactor, clock):
@@ -266,7 +261,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media into the room
response = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=admin_user_tok
)
# Extract media ID from the response
@@ -314,10 +309,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract mxcs
@@ -381,10 +376,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract media IDs
@@ -421,10 +416,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
- self.upload_resource, self.image_data, tok=non_admin_user_tok
+ self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract media IDs
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 972d60570c..2f02934e72 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -1,4 +1,5 @@
# Copyright 2020 Dirk Klimpel
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +15,6 @@
import json
import os
-from binascii import unhexlify
from parameterized import parameterized
@@ -25,6 +25,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@@ -110,15 +111,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
download_resource = self.media_repo.children[b"download"]
upload_resource = self.media_repo.children[b"upload"]
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -504,16 +500,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
Create a media and return media_id and server_and_media_id
"""
upload_resource = self.media_repo.children[b"upload"]
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -584,16 +574,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
# Create media
upload_resource = media_repo.children[b"upload"]
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
@@ -711,16 +695,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
# Create media
upload_resource = media_repo.children[b"upload"]
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
# Upload some media into the room
response = self.helper.upload_media(
- upload_resource, image_data, tok=self.admin_user_tok, expect_code=200
+ upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 40e032df7f..e798513ac1 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -941,6 +941,33 @@ class RoomTestCase(unittest.HomeserverTestCase):
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
+ def test_search_term_non_ascii(self):
+ """Test that searching for a room with non-ASCII characters works correctly"""
+
+ # Create test room
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+ room_name = "ж"
+
+ # Set the name for the room
+ self.helper.send_state(
+ room_id,
+ "m.room.name",
+ {"name": room_name},
+ tok=self.admin_user_tok,
+ )
+
+ # make the request and test that the response is what we wanted
+ search_term = urllib.parse.quote("ж", "utf-8")
+ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
+ self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
+
def test_single_room(self):
"""Test that a single room can be requested correctly"""
# Create two test rooms
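For reference, the encoding step the new test depends on: `urllib.parse.quote` percent-encodes non-ASCII text as UTF-8 by default, and its second positional parameter is `safe` (characters to leave unescaped), not an encoding, so the term can simply be passed alone:

```python
import urllib.parse

# 'ж' is U+0436; its UTF-8 encoding is the two bytes D0 B6.
search_term = urllib.parse.quote("ж")
assert search_term == "%D0%B6"

url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
assert url == "/_synapse/admin/v1/rooms?search_term=%D0%B6"
```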
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5cd82209c4..ece89a65ac 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -1,4 +1,5 @@
# Copyright 2020 Dirk Klimpel
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,6 @@
# limitations under the License.
import json
-from binascii import unhexlify
from typing import Any, Dict, List, Optional
import synapse.rest.admin
@@ -21,6 +21,7 @@ from synapse.api.errors import Codes
from synapse.rest.client import login
from tests import unittest
+from tests.test_utils import SMALL_PNG
class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
@@ -468,16 +469,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
upload_resource = self.media_repo.children[b"upload"]
for _ in range(number_media):
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
# Upload some media into the room
self.helper.upload_media(
- upload_resource, image_data, tok=user_token, expect_code=200
+ upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
def _check_fields(self, content: List[Dict[str, Any]]):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ee204c404b..cc3f16c62a 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeSite, make_request
-from tests.test_utils import make_awaitable
+from tests.test_utils import SMALL_PNG, make_awaitable
from tests.unittest import override_config
@@ -2835,11 +2835,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
- image_data1 = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
+ image_data1 = SMALL_PNG
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
image_data2 = unhexlify(
b"47494638376101000100800100000000"
@@ -2943,14 +2939,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
"""
media_ids = []
for _ in range(number_media):
- # file size is 67 Byte
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
- media_ids.append(self._create_media_and_access(user_token, image_data))
+ media_ids.append(self._create_media_and_access(user_token, SMALL_PNG))
return media_ids
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 50100a5ae4..ef847f0f5f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -26,7 +26,7 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
@@ -377,6 +377,91 @@ class RoomPermissionsTestCase(RoomBase):
expect_code=403,
)
+ # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
+ def test_member_event_from_ban(self):
+ room = self.created_rmid
+ self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
+ self.helper.join(room=room, user=self.user_id)
+
+ other = "@burgundy:red"
+
+ # User cannot ban other since they do not have the required power level
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.BAN,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.FORBIDDEN,
+ )
+
+ # Admin bans other
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.BAN,
+ expect_code=200,
+ )
+
+ # from ban to invite: Must never happen.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.INVITE,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.BAD_STATE,
+ )
+
+ # from ban to join: Must never happen.
+ self.helper.change_membership(
+ room=room,
+ src=other,
+ targ=other,
+ membership=Membership.JOIN,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.BAD_STATE,
+ )
+
+ # from ban to ban: No change.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.BAN,
+ expect_code=200,
+ )
+
+ # from ban to knock: Must never happen.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.KNOCK,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.BAD_STATE,
+ )
+
+ # User cannot unban other since they do not have the required power level
+ self.helper.change_membership(
+ room=room,
+ src=self.user_id,
+ targ=other,
+ membership=Membership.LEAVE,
+ expect_code=403, # expect failure
+ expect_errcode=Codes.FORBIDDEN,
+ )
+
+ # from ban to leave: User was unbanned.
+ self.helper.change_membership(
+ room=room,
+ src=self.rmcreator_id,
+ targ=other,
+ membership=Membership.LEAVE,
+ expect_code=200,
+ )
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -784,6 +869,12 @@ class RoomJoinRatelimitTestCase(RoomBase):
room.register_servlets,
]
+ def prepare(self, reactor, clock, homeserver):
+ super().prepare(reactor, clock, homeserver)
+ # profile changes expect that the user is actually registered
+ user = UserID.from_string(self.user_id)
+ self.get_success(self.register_user(user.localpart, "supersecretpassword"))
+
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
@@ -813,12 +904,6 @@ class RoomJoinRatelimitTestCase(RoomBase):
# join in a second.
room_ids.append(self.helper.create_room_as(self.user_id))
- # Create a profile for the user, since it hasn't been done on registration.
- store = self.hs.get_datastore()
- self.get_success(
- store.create_profile(UserID.from_string(self.user_id).localpart)
- )
-
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
channel = self.make_request("PUT", path, {"displayname": "John Doe"})
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 6a0d9a82be..b0c44af033 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -193,7 +193,7 @@ class RoomTestCase(_ShadowBannedBase):
self.assertEquals(200, channel.code)
# There should be no typing events.
- event_source = self.hs.get_event_sources().sources["typing"]
+ event_source = self.hs.get_event_sources().sources.typing
self.assertEquals(event_source.get_current_key(), 0)
# The other user can join and send typing events.
@@ -210,7 +210,13 @@ class RoomTestCase(_ShadowBannedBase):
# These appear in the room.
self.assertEquals(event_source.get_current_key(), 1)
events = self.get_success(
- event_source.get_new_events(from_key=0, room_ids=[room_id])
+ event_source.get_new_events(
+ user=UserID.from_string(self.other_user_id),
+ from_key=0,
+ limit=None,
+ room_ids=[room_id],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0ae4029640..38ac9be113 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,6 +15,7 @@ import threading
from typing import Dict
from unittest.mock import Mock
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
@@ -327,3 +328,86 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
+
+ def test_sent_event_end_up_in_room_state(self):
+ """Tests that a state event sent by a module while processing another state event
+ doesn't get dropped from the state of the room. This is to guard against a bug
+ where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
+ """
+ event_type = "org.matrix.test_state"
+
+ # This content will be updated later on; since the callback below keeps a
+ # reference to this dict, mutating it here changes what the callback sends.
+ # It's a bit hacky, but a handy way of making sure the state actually gets
+ # updated.
+ event_content = {"i": -1}
+
+ api = self.hs.get_module_api()
+
+ # Define a callback that sends a custom event on power levels update.
+ async def test_fn(event: EventBase, state_events):
+ if event.is_state() and event.type == EventTypes.PowerLevels:
+ await api.create_and_send_event_into_room(
+ {
+ "room_id": event.room_id,
+ "sender": event.sender,
+ "type": event_type,
+ "content": event_content,
+ "state_key": "",
+ }
+ )
+ return True, None
+
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [test_fn]
+
+ # Sometimes the bug might not happen the first time the event type is added
+ # to the state but might happen when an event updates the state of the room for
+ # that type, so we test updating the state several times.
+ for i in range(5):
+ # Update the content of the custom state event to be sent by the callback.
+ event_content["i"] = i
+
+ # Update the room's power levels with a different value each time so Synapse
+ # doesn't consider an update redundant.
+ self._update_power_levels(event_default=i)
+
+ # Check that the new event made it to the room's state.
+ channel = self.make_request(
+ method="GET",
+ path="/rooms/" + self.room_id + "/state/" + event_type,
+ access_token=self.tok,
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["i"], i)
+
+ def _update_power_levels(self, event_default: int = 0):
+ """Updates the room's power levels.
+
+ Args:
+ event_default: Value to use for 'events_default'.
+ """
+ self.helper.send_state(
+ room_id=self.room_id,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.encryption": 100,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ "m.room.server_acl": 100,
+ "m.room.tombstone": 100,
+ },
+ "events_default": event_default,
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "state_default": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ },
+ tok=self.tok,
+ )
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..ee0abd5295 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -41,7 +41,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
federation_client=Mock(),
)
- self.event_source = hs.get_event_sources().sources["typing"]
+ self.event_source = hs.get_event_sources().sources.typing
hs.get_federation_handler = Mock()
@@ -76,7 +76,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
events = self.get_success(
- self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ self.event_source.get_new_events(
+ user=UserID.from_string(self.user_id),
+ from_key=0,
+ limit=None,
+ room_ids=[self.room_id],
+ is_guest=False,
+ )
)
self.assertEquals(
events[0],
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 954ad1a1fd..c56e45fc10 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -138,6 +138,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
+ expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -150,6 +151,7 @@ class RestHelper:
extra_data: Extra information to include in the content of the event
tok: The user access token to use
expect_code: The expected HTTP response code
+ expect_errcode: The expected Matrix error code
"""
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -177,6 +179,15 @@ class RestHelper:
channel.result["body"],
)
+ if expect_errcode:
+ assert (
+ str(channel.json_body["errcode"]) == expect_errcode
+ ), "Expected: %r, got: %r, resp: %r" % (
+ expect_errcode,
+ channel.json_body["errcode"],
+ channel.result["body"],
+ )
+
self.auth_user_id = temp_id
def send(
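With the new parameter, callers can pin down both the HTTP status and the Matrix error code in one call. A usage fragment, to be read inside a HomeserverTestCase method (with `Membership` and `Codes` imported as in test_rooms.py):

```python
from synapse.api.constants import Membership
from synapse.api.errors import Codes

# Assert both the HTTP status and the Matrix errcode of a rejected
# membership change in a single helper call.
self.helper.change_membership(
    room=room_id,
    src=self.user_id,
    targ="@burgundy:red",
    membership=Membership.BAN,
    expect_code=403,                 # HTTP-level failure
    expect_errcode=Codes.FORBIDDEN,  # Matrix-level error code
)
```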
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2f7eebfe69..9ea1c2bf25 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,6 +38,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.test_utils import SMALL_PNG
from tests.utils import default_config
@@ -134,11 +135,7 @@ class _TestImage:
# smoll png
(
_TestImage(
- unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- ),
+ SMALL_PNG,
b"image/png",
b".png",
unhexlify(
@@ -593,15 +590,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
def test_upload_innocent(self):
"""Attempt to upload some innocent data that should be allowed."""
-
- image_data = unhexlify(
- b"89504e470d0a1a0a0000000d4948445200000001000000010806"
- b"0000001f15c4890000000a49444154789c63000100000500010d"
- b"0a2db40000000049454e44ae426082"
- )
-
self.helper.upload_media(
- self.upload_resource, image_data, tok=self.tok, expect_code=200
+ self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
def test_upload_ban(self):
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 9f6fbfe6de..9d13899584 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -24,6 +24,7 @@ from synapse.config.oembed import OEmbedEndpointConfig
from tests import unittest
from tests.server import FakeTransport
+from tests.test_utils import SMALL_PNG
try:
import lxml
@@ -576,13 +577,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
}
oembed_content = json.dumps(result).encode("utf-8")
- end_content = (
- b"<html><head>"
- b"<title>Some Title</title>"
- b'<meta property="og:description" content="hi" />'
- b"</head></html>"
- )
-
channel = self.make_request(
"GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345",
@@ -606,6 +600,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
+ # Ensure a second request is made to the photo URL.
client = self.reactor.tcpClients[1][2].buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
@@ -613,18 +608,23 @@ class URLPreviewTests(unittest.HomeserverTestCase):
client.dataReceived(
(
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
- b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+ b"Content-Type: image/png\r\n\r\n"
)
- % (len(end_content),)
- + end_content
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
)
self.pump()
+ # Ensure the URL is what was requested.
+ self.assertIn(b"/matrixdotorg", server.data)
+
self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
- )
+ self.assertIsNone(channel.json_body["og:title"])
+ self.assertTrue(channel.json_body["og:image"].startswith("mxc://"))
+ self.assertEqual(channel.json_body["og:image:height"], 1)
+ self.assertEqual(channel.json_body["og:image:width"], 1)
+ self.assertEqual(channel.json_body["og:image:type"], "image/png")
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8695264595..32060f2abd 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -14,6 +14,8 @@
import logging
+from frozendict import frozendict
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
- types={EventTypes.Member: {self.u_alice.to_string()}},
+ types=frozendict(
+ {EventTypes.Member: frozenset({self.u_alice.to_string()})}
+ ),
include_others=True,
),
)
@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}),
+ include_others=True,
),
)
)
@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=False
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=False,
),
)
@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: set()}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset()}), include_others=True
),
)
@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: None}, include_others=True
+ types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=True
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=True,
),
)
@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=False
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=False,
),
)
@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
- types={EventTypes.Member: {e5.state_key}}, include_others=False
+ types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+ include_others=False,
),
)
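The motivation for wrapping every `types` mapping in frozendict/frozenset: plain dicts and sets are unhashable, so a StateFilter built from them cannot serve as (part of) a cache key, while the frozen variants can. A standalone illustration:

```python
from frozendict import frozendict

# A plain dict (or set) is unhashable, so it cannot be part of a cache key:
try:
    hash({"m.room.member": {"@alice:test"}})
except TypeError as e:
    print(e)  # unhashable type: 'dict'

# The frozen equivalents hash fine, which is what the updated tests now pass
# into StateFilter(types=..., include_others=...).
types = frozendict({"m.room.member": frozenset({"@alice:test"})})
assert isinstance(hash(types), int)
```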
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index be6302d170..15ac2bfeba 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,5 +1,4 @@
-# Copyright 2019 New Vector Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +18,7 @@ Utilities for running the unit tests
import sys
import warnings
from asyncio import Future
+from binascii import unhexlify
from typing import Any, Awaitable, Callable, TypeVar
from unittest.mock import Mock
@@ -117,3 +117,13 @@ class FakeResponse:
def deliverBody(self, protocol):
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
+
+
+# A small image used in some tests.
+#
+# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
+SMALL_PNG = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+)
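A quick, self-contained sanity check of the shared fixture: the bytes begin with the PNG file signature, and the IHDR chunk declares a 1×1 image, matching the comment above:

```python
import struct
from binascii import unhexlify

SMALL_PNG = unhexlify(
    b"89504e470d0a1a0a0000000d4948445200000001000000010806"
    b"0000001f15c4890000000a49444154789c63000100000500010d"
    b"0a2db40000000049454e44ae426082"
)

assert len(SMALL_PNG) == 67                   # "Size: 67 B"
assert SMALL_PNG[:8] == b"\x89PNG\r\n\x1a\n"  # PNG file signature
# IHDR width and height are big-endian u32s at byte offsets 16 and 20.
assert struct.unpack(">II", SMALL_PNG[16:24]) == (1, 1)
```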