diff --git a/CHANGES.md b/CHANGES.md
index a299110a6b..a35f5aebc7 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,85 @@
+Synapse 0.33.3 (2018-08-22)
+===========================
+
+Bugfixes
+--------
+
+- Fix bug introduced in v0.33.3rc1 which made the ToS give a 500 error ([\#3732](https://github.com/matrix-org/synapse/issues/3732))
+
+
+Synapse 0.33.3rc2 (2018-08-21)
+==============================
+
+Bugfixes
+--------
+
+- Fix bug in v0.33.3rc1 which caused infinite loops and OOMs ([\#3723](https://github.com/matrix-org/synapse/issues/3723))
+
+
+Synapse 0.33.3rc1 (2018-08-21)
+==============================
+
+Features
+--------
+
+- Add support for the SNI extension to federation TLS connections. Thanks to @vojeroen! ([\#3439](https://github.com/matrix-org/synapse/issues/3439))
+- Add /_media/r0/config ([\#3184](https://github.com/matrix-org/synapse/issues/3184))
+- Speed up the /members API and add `at` and `membership` params as per MSC1227 ([\#3568](https://github.com/matrix-org/synapse/issues/3568))
+- Implement the `summary` block in the /sync response as per MSC688 ([\#3574](https://github.com/matrix-org/synapse/issues/3574))
+- Add lazy-loading support to /messages as per MSC1227 ([\#3589](https://github.com/matrix-org/synapse/issues/3589))
+- Add ability to limit number of monthly active users on the server ([\#3633](https://github.com/matrix-org/synapse/issues/3633))
+- Support more federation endpoints on workers ([\#3653](https://github.com/matrix-org/synapse/issues/3653))
+- Basic support for room versioning ([\#3654](https://github.com/matrix-org/synapse/issues/3654))
+- Add ability to disable client/server Synapse via a config toggle ([\#3655](https://github.com/matrix-org/synapse/issues/3655))
+- Add ability to whitelist specific threepids against monthly active user limiting ([\#3662](https://github.com/matrix-org/synapse/issues/3662))
+- Add some metrics for the appservice and federation event sending loops ([\#3664](https://github.com/matrix-org/synapse/issues/3664))
+- Where the server is disabled, block locked-out users from reading new messages ([\#3670](https://github.com/matrix-org/synapse/issues/3670))
+- Set the admin URI via config, to be used in error messages where the user should contact the administrator ([\#3687](https://github.com/matrix-org/synapse/issues/3687))
+- Synapse's presence functionality can now be disabled with the "use_presence" configuration option. ([\#3694](https://github.com/matrix-org/synapse/issues/3694))
+- Prevent users who are blocked due to resource limits from writing into rooms ([\#3708](https://github.com/matrix-org/synapse/issues/3708))
+
+
+Bugfixes
+--------
+
+- Fix occasional glitches in the synapse_event_persisted_position metric ([\#3658](https://github.com/matrix-org/synapse/issues/3658))
+- Fix bug on deleting 3pid when using identity servers that don't support unbind API ([\#3661](https://github.com/matrix-org/synapse/issues/3661))
+- Make the tests pass on Twisted < 18.7.0 ([\#3676](https://github.com/matrix-org/synapse/issues/3676))
+- Don’t ship recaptcha_ajax.js; use it directly from Google ([\#3677](https://github.com/matrix-org/synapse/issues/3677))
+- Fix test_reap_monthly_active_users so it passes under PostgreSQL ([\#3681](https://github.com/matrix-org/synapse/issues/3681))
+- Fix MAU blocking calculation bug on login ([\#3689](https://github.com/matrix-org/synapse/issues/3689))
+- Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users ([\#3692](https://github.com/matrix-org/synapse/issues/3692))
+- Improve HTTP request logging to include all requests ([\#3700](https://github.com/matrix-org/synapse/issues/3700))
+- Avoid timing out requests while we are streaming back the response ([\#3701](https://github.com/matrix-org/synapse/issues/3701))
+- Support more federation endpoints on workers ([\#3705](https://github.com/matrix-org/synapse/issues/3705), [\#3713](https://github.com/matrix-org/synapse/issues/3713))
+- Fix "Starting db txn 'get_all_updated_receipts' from sentinel context" warning ([\#3710](https://github.com/matrix-org/synapse/issues/3710))
+- Fix bug where `state_cache` cache factor ignored environment variables ([\#3719](https://github.com/matrix-org/synapse/issues/3719))
+
+
+Deprecations and Removals
+-------------------------
+
+- The Shared-Secret registration method of the legacy v1/register REST endpoint has been removed. For a replacement, please see [the admin/register API documentation](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/register_api.rst). ([\#3703](https://github.com/matrix-org/synapse/issues/3703))
+
+
+Internal Changes
+----------------
+
+- The test suite can now run under PostgreSQL. ([\#3423](https://github.com/matrix-org/synapse/issues/3423))
+- Refactor HTTP replication endpoints to reduce code duplication ([\#3632](https://github.com/matrix-org/synapse/issues/3632))
+- Tests now correctly execute on Python 3. ([\#3647](https://github.com/matrix-org/synapse/issues/3647))
+- Sytests can now be run inside a Docker container. ([\#3660](https://github.com/matrix-org/synapse/issues/3660))
+- Port enough of the code to Python 3 to allow the sytests to start. ([\#3668](https://github.com/matrix-org/synapse/issues/3668))
+- Update docker base image from alpine 3.7 to 3.8. ([\#3669](https://github.com/matrix-org/synapse/issues/3669))
+- Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7. ([\#3678](https://github.com/matrix-org/synapse/issues/3678))
+- Synapse's tests are now formatted with the black autoformatter. ([\#3679](https://github.com/matrix-org/synapse/issues/3679))
+- Implemented a new testing base class to reduce test boilerplate. ([\#3684](https://github.com/matrix-org/synapse/issues/3684))
+- Rename MAU prometheus metrics ([\#3690](https://github.com/matrix-org/synapse/issues/3690))
+- Add new error type ResourceLimit ([\#3707](https://github.com/matrix-org/synapse/issues/3707))
+- Logcontexts for replication command handlers ([\#3709](https://github.com/matrix-org/synapse/issues/3709))
+- Update admin register API documentation to reference a real user ID. ([\#3712](https://github.com/matrix-org/synapse/issues/3712))
+
+
Synapse 0.33.2 (2018-08-09)
===========================
@@ -24,7 +106,7 @@ Features
Bugfixes
--------
-- Make /directory/list API return 404 for room not found instead of 400 ([\#2952](https://github.com/matrix-org/synapse/issues/2952))
+- Make /directory/list API return 404 for room not found instead of 400. Thanks to @fuzzmz! ([\#3620](https://github.com/matrix-org/synapse/issues/3620))
- Default inviter_display_name to mxid for email invites ([\#3391](https://github.com/matrix-org/synapse/issues/3391))
- Don't generate TURN credentials if no TURN config options are set ([\#3514](https://github.com/matrix-org/synapse/issues/3514))
- Correctly announce deleted devices over federation ([\#3520](https://github.com/matrix-org/synapse/issues/3520))
diff --git a/changelog.d/1491.feature b/changelog.d/1491.feature
deleted file mode 100644
index 77b6d6ca09..0000000000
--- a/changelog.d/1491.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add support for the SNI extension to federation TLS connections
\ No newline at end of file
diff --git a/changelog.d/3184.feature b/changelog.d/3184.feature
deleted file mode 100644
index 9f746a57a0..0000000000
--- a/changelog.d/3184.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add /_media/r0/config
diff --git a/changelog.d/3423.misc b/changelog.d/3423.misc
deleted file mode 100644
index 51768c6d14..0000000000
--- a/changelog.d/3423.misc
+++ /dev/null
@@ -1 +0,0 @@
-The test suite now can run under PostgreSQL.
diff --git a/changelog.d/3568.feature b/changelog.d/3568.feature
deleted file mode 100644
index 247f02ba4e..0000000000
--- a/changelog.d/3568.feature
+++ /dev/null
@@ -1 +0,0 @@
-speed up /members API and add `at` and `membership` params as per MSC1227
diff --git a/changelog.d/3574.feature b/changelog.d/3574.feature
deleted file mode 100644
index 87ac32df72..0000000000
--- a/changelog.d/3574.feature
+++ /dev/null
@@ -1 +0,0 @@
-implement `summary` block in /sync response as per MSC688
diff --git a/changelog.d/3589.feature b/changelog.d/3589.feature
deleted file mode 100644
index a8d7124719..0000000000
--- a/changelog.d/3589.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add lazy-loading support to /messages as per MSC1227
diff --git a/changelog.d/3632.misc b/changelog.d/3632.misc
deleted file mode 100644
index 9d64bbe83b..0000000000
--- a/changelog.d/3632.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor HTTP replication endpoints to reduce code duplication
diff --git a/changelog.d/3633.feature b/changelog.d/3633.feature
deleted file mode 100644
index 8007a04840..0000000000
--- a/changelog.d/3633.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add ability to limit number of monthly active users on the server
diff --git a/changelog.d/3647.misc b/changelog.d/3647.misc
deleted file mode 100644
index dbc66dae60..0000000000
--- a/changelog.d/3647.misc
+++ /dev/null
@@ -1 +0,0 @@
-Tests now correctly execute on Python 3.
diff --git a/changelog.d/3653.feature b/changelog.d/3653.feature
deleted file mode 100644
index 6c5422994f..0000000000
--- a/changelog.d/3653.feature
+++ /dev/null
@@ -1 +0,0 @@
-Support more federation endpoints on workers
diff --git a/changelog.d/3654.feature b/changelog.d/3654.feature
deleted file mode 100644
index 35c95580bc..0000000000
--- a/changelog.d/3654.feature
+++ /dev/null
@@ -1 +0,0 @@
-Basic support for room versioning
diff --git a/changelog.d/3655.feature b/changelog.d/3655.feature
deleted file mode 100644
index 1134e549e7..0000000000
--- a/changelog.d/3655.feature
+++ /dev/null
@@ -1 +0,0 @@
-Ability to disable client/server Synapse via conf toggle
diff --git a/changelog.d/3658.bugfix b/changelog.d/3658.bugfix
deleted file mode 100644
index 556011a150..0000000000
--- a/changelog.d/3658.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix occasional glitches in the synapse_event_persisted_position metric
diff --git a/changelog.d/3659.feature b/changelog.d/3659.feature
new file mode 100644
index 0000000000..a5b4821c09
--- /dev/null
+++ b/changelog.d/3659.feature
@@ -0,0 +1 @@
+Support profile API endpoints on workers
diff --git a/changelog.d/3660.misc b/changelog.d/3660.misc
deleted file mode 100644
index acd814c273..0000000000
--- a/changelog.d/3660.misc
+++ /dev/null
@@ -1 +0,0 @@
-Sytests can now be run inside a Docker container.
diff --git a/changelog.d/3661.bugfix b/changelog.d/3661.bugfix
deleted file mode 100644
index f2b4703d80..0000000000
--- a/changelog.d/3661.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix bug on deleting 3pid when using identity servers that don't support unbind API
diff --git a/changelog.d/3662.feature b/changelog.d/3662.feature
deleted file mode 100644
index daacef086d..0000000000
--- a/changelog.d/3662.feature
+++ /dev/null
@@ -1 +0,0 @@
-Ability to whitelist specific threepids against monthly active user limiting
diff --git a/changelog.d/3664.feature b/changelog.d/3664.feature
deleted file mode 100644
index 184dde9939..0000000000
--- a/changelog.d/3664.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add some metrics for the appservice and federation event sending loops
diff --git a/changelog.d/3669.misc b/changelog.d/3669.misc
deleted file mode 100644
index fc579ddc60..0000000000
--- a/changelog.d/3669.misc
+++ /dev/null
@@ -1 +0,0 @@
-Update docker base image from alpine 3.7 to 3.8.
diff --git a/changelog.d/3670.feature b/changelog.d/3670.feature
deleted file mode 100644
index ba00f2d2ec..0000000000
--- a/changelog.d/3670.feature
+++ /dev/null
@@ -1 +0,0 @@
-Where server is disabled, block ability for locked out users to read new messages
diff --git a/changelog.d/3673.misc b/changelog.d/3673.misc
new file mode 100644
index 0000000000..d672111fb9
--- /dev/null
+++ b/changelog.d/3673.misc
@@ -0,0 +1 @@
+Refactor state module to support multiple room versions
diff --git a/changelog.d/3676.bugfix b/changelog.d/3676.bugfix
deleted file mode 100644
index 7b23a2773a..0000000000
--- a/changelog.d/3676.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Make the tests pass on Twisted < 18.7.0
diff --git a/changelog.d/3677.bugfix b/changelog.d/3677.bugfix
deleted file mode 100644
index caa551627b..0000000000
--- a/changelog.d/3677.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Don’t ship recaptcha_ajax.js, use it directly from Google
diff --git a/changelog.d/3678.misc b/changelog.d/3678.misc
deleted file mode 100644
index 0d7c8da64a..0000000000
--- a/changelog.d/3678.misc
+++ /dev/null
@@ -1 +0,0 @@
-Rename synapse.util.async to synapse.util.async_helpers to mitigate async becoming a keyword on Python 3.7.
diff --git a/changelog.d/3679.misc b/changelog.d/3679.misc
deleted file mode 100644
index 1de0a0f2b4..0000000000
--- a/changelog.d/3679.misc
+++ /dev/null
@@ -1 +0,0 @@
-Synapse's tests are now formatted with the black autoformatter.
diff --git a/changelog.d/3681.bugfix b/changelog.d/3681.bugfix
deleted file mode 100644
index d18a69cd0c..0000000000
--- a/changelog.d/3681.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fixes test_reap_monthly_active_users so it passes under postgres
diff --git a/changelog.d/3684.misc b/changelog.d/3684.misc
deleted file mode 100644
index 4c013263c4..0000000000
--- a/changelog.d/3684.misc
+++ /dev/null
@@ -1 +0,0 @@
-Implemented a new testing base class to reduce test boilerplate.
diff --git a/changelog.d/3687.feature b/changelog.d/3687.feature
deleted file mode 100644
index 64b89f6411..0000000000
--- a/changelog.d/3687.feature
+++ /dev/null
@@ -1 +0,0 @@
-set admin uri via config, to be used in error messages where the user should contact the administrator
diff --git a/changelog.d/3689.bugfix b/changelog.d/3689.bugfix
deleted file mode 100644
index 934d039836..0000000000
--- a/changelog.d/3689.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix mau blocking calulation bug on login
diff --git a/changelog.d/3690.misc b/changelog.d/3690.misc
deleted file mode 100644
index 710add0243..0000000000
--- a/changelog.d/3690.misc
+++ /dev/null
@@ -1 +0,0 @@
-Rename MAU prometheus metrics
diff --git a/changelog.d/3692.bugfix b/changelog.d/3692.bugfix
deleted file mode 100644
index f44e13dca1..0000000000
--- a/changelog.d/3692.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix missing yield in synapse.storage.monthly_active_users.initialise_reserved_users
diff --git a/changelog.d/3705.bugfix b/changelog.d/3705.bugfix
deleted file mode 100644
index 6c5422994f..0000000000
--- a/changelog.d/3705.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Support more federation endpoints on workers
diff --git a/changelog.d/3707.misc b/changelog.d/3707.misc
deleted file mode 100644
index 8123ca6543..0000000000
--- a/changelog.d/3707.misc
+++ /dev/null
@@ -1 +0,0 @@
-add new error type ResourceLimit
diff --git a/changelog.d/3708.feature b/changelog.d/3708.feature
deleted file mode 100644
index 2f146ba62b..0000000000
--- a/changelog.d/3708.feature
+++ /dev/null
@@ -1 +0,0 @@
-For resource limit blocked users, prevent writing into rooms
diff --git a/changelog.d/3712.misc b/changelog.d/3712.misc
deleted file mode 100644
index 30f8c2af21..0000000000
--- a/changelog.d/3712.misc
+++ /dev/null
@@ -1 +0,0 @@
-Update admin register API documentation to reference a real user ID.
diff --git a/changelog.d/3722.bugfix b/changelog.d/3722.bugfix
new file mode 100644
index 0000000000..16cbaf76cb
--- /dev/null
+++ b/changelog.d/3722.bugfix
@@ -0,0 +1 @@
+Fix error collecting prometheus metrics when run on a dedicated thread, due to threading concurrency issues
diff --git a/changelog.d/3724.feature b/changelog.d/3724.feature
new file mode 100644
index 0000000000..1b374ccf47
--- /dev/null
+++ b/changelog.d/3724.feature
@@ -0,0 +1 @@
+Allow guests to use /rooms/:roomId/event/:eventId
diff --git a/changelog.d/3726.misc b/changelog.d/3726.misc
new file mode 100644
index 0000000000..c4f66ec998
--- /dev/null
+++ b/changelog.d/3726.misc
@@ -0,0 +1 @@
+Split the state_group_cache into member and non-member state events (and so speed up LL /sync)
diff --git a/changelog.d/3727.misc b/changelog.d/3727.misc
new file mode 100644
index 0000000000..0b83220d90
--- /dev/null
+++ b/changelog.d/3727.misc
@@ -0,0 +1 @@
+Log failure to authenticate remote servers as warnings (without stack traces)
diff --git a/changelog.d/3735.misc b/changelog.d/3735.misc
new file mode 100644
index 0000000000..f17004be71
--- /dev/null
+++ b/changelog.d/3735.misc
@@ -0,0 +1 @@
+Fix minor spelling error in federation client documentation.
diff --git a/docs/workers.rst b/docs/workers.rst
index ac9efb621f..81146a211f 100644
--- a/docs/workers.rst
+++ b/docs/workers.rst
@@ -241,6 +241,14 @@ regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
+If ``use_presence`` is False in the homeserver config, it can also handle REST
+endpoints matching the following regular expressions::
+
+ ^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status
+
+This "stub" presence handler will pass through ``GET`` request but make the
+``PUT`` effectively a no-op.
+
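+For example, a deployment that disables presence would set the following in
+the homeserver config (the value shown is illustrative)::
+
+    use_presence: false
+
+The worker then answers presence ``PUT`` requests itself and forwards ``GET``
+requests to the main instance via ``worker_main_http_uri``.
+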
It will proxy any requests it cannot handle to the main synapse instance. It
must therefore be configured with the location of the main instance, via
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
@@ -257,6 +265,7 @@ Handles some event creation. It can handle REST endpoints matching::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
^/_matrix/client/(api/v1|r0|unstable)/join/
+ ^/_matrix/client/(api/v1|r0|unstable)/profile/
It will create events locally and then send them on to the main synapse
instance to be persisted and handled.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index a14d578e36..e62901b761 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.33.2"
+__version__ = "0.33.3"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 4207a48afd..4ca40a0f71 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -211,7 +211,7 @@ class Auth(object):
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent",
default=[b""]
- )[0]
+ )[0].decode('ascii', 'surrogateescape')
if user and access_token and ip_addr:
yield self.store.insert_client_ip(
user_id=user.to_string(),
@@ -682,7 +682,7 @@ class Auth(object):
Returns:
bool: False if no access_token was given, True otherwise.
"""
- query_params = request.args.get("access_token")
+ query_params = request.args.get(b"access_token")
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -698,7 +698,7 @@ class Auth(object):
401 since some of the old clients depended on auth errors returning
403.
Returns:
- str: The access_token
+ unicode: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
@@ -720,9 +720,9 @@ class Auth(object):
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
- parts = auth_headers[0].split(" ")
- if parts[0] == "Bearer" and len(parts) == 2:
- return parts[1]
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ return parts[1].decode('ascii')
else:
raise AuthError(
token_not_found_http_status,
@@ -738,7 +738,7 @@ class Auth(object):
errcode=Codes.MISSING_TOKEN
)
- return query_params[0]
+ return query_params[0].decode('ascii')
@defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id):
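To make the new byte/unicode handling concrete, here is a minimal standalone
sketch of the bearer-token extraction (illustrative only; it mirrors the
change above rather than importing synapse):

    # Header values arrive as bytes under Python 3, so the split and the
    # comparison must use byte strings; the token is then decoded to unicode.
    auth_header = b"Bearer some_opaque_token"
    parts = auth_header.split(b" ")
    if parts[0] == b"Bearer" and len(parts) == 2:
        token = parts[1].decode('ascii')
        print(token)  # -> some_opaque_token, now a unicode string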
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a67862f4ed..c2630c4c64 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -98,13 +98,17 @@ class ThirdPartyEntityKind(object):
LOCATION = "location"
+class RoomVersions(object):
+ V1 = "1"
+ VDH_TEST = "vdh-test-version"
+
+
# the version we will give rooms which are created on this server
-DEFAULT_ROOM_VERSION = "1"
+DEFAULT_ROOM_VERSION = RoomVersions.V1
# vdh-test-version is a placeholder to get room versioning support working and tested
# until we have a working v2.
-KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"}
-
+KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}
ServerNoticeMsgType = "m.server_notice"
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 06cc8d90b8..3bb5b3da37 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -72,7 +72,7 @@ class Ratelimiter(object):
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
- for user_id in self.message_counts.keys():
+ for user_id in list(self.message_counts.keys()):
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
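This change matters on Python 3, where dict.keys() returns a live view and
deleting entries mid-iteration raises a RuntimeError; materialising the keys
into a list first avoids that. A standalone sketch:

    message_counts = {"@alice:hs": 0, "@bob:hs": 3}

    # Iterating message_counts.keys() directly while deleting would raise
    # "RuntimeError: dictionary changed size during iteration" on Python 3.
    for user_id in list(message_counts.keys()):
        if message_counts[user_id] == 0:
            del message_counts[user_id]

    print(message_counts)  # {'@bob:hs': 3}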
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 391bd14c5c..7c866e246a 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -140,7 +140,7 @@ def listen_metrics(bind_addresses, port):
logger.info("Metrics now reporting on %s:%d", host, port)
-def listen_tcp(bind_addresses, port, factory, backlog=50):
+def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
"""
Create a TCP socket for a port and several addresses
"""
@@ -156,7 +156,9 @@ def listen_tcp(bind_addresses, port, factory, backlog=50):
check_bind_error(e, address, bind_addresses)
-def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+def listen_ssl(
+ bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
+):
"""
Create an SSL socket for a port and several addresses
"""
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 9a37384fb7..3348a8ec6d 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -117,8 +117,9 @@ class ASReplicationHandler(ReplicationClientHandler):
super(ASReplicationHandler, self).__init__(hs.get_datastore())
self.appservice_handler = hs.get_application_service_handler()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+ yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
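The same shape recurs in each worker's replication handler below: on_rdata
becomes an inlineCallbacks coroutine that yields the superclass call, so any
asynchronous work the base class starts completes before the stream-specific
processing runs. A minimal sketch, assuming Twisted is installed (the class
names here are hypothetical):

    from twisted.internet import defer

    class BaseReplicationHandler(object):
        def on_rdata(self, stream_name, token, rows):
            # The base implementation may now return a Deferred.
            return defer.succeed(None)

    class WorkerReplicationHandler(BaseReplicationHandler):
        @defer.inlineCallbacks
        def on_rdata(self, stream_name, token, rows):
            # Wait for the base class to finish before acting on the rows.
            yield super(WorkerReplicationHandler, self).on_rdata(
                stream_name, token, rows
            )
            # ... per-stream processing would go here ...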
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 03d39968a8..a34c89fa99 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.profile import (
+ ProfileAvatarURLRestServlet,
+ ProfileDisplaynameRestServlet,
+ ProfileRestServlet,
+)
from synapse.rest.client.v1.room import (
JoinRoomAliasServlet,
RoomMembershipRestServlet,
@@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import (
)
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
+from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
@@ -62,6 +68,9 @@ logger = logging.getLogger("synapse.app.event_creator")
class EventCreatorSlavedStore(
+ # FIXME(#3714): We need to add UserDirectoryStore as we write directly
+ # rather than going via the correct worker.
+ UserDirectoryStore,
DirectoryStore,
SlavedTransactionStore,
SlavedProfileStore,
@@ -101,6 +110,9 @@ class EventCreatorServer(HomeServer):
RoomMembershipRestServlet(self).register(resource)
RoomStateEventRestServlet(self).register(resource)
JoinRoomAliasServlet(self).register(resource)
+ ProfileAvatarURLRestServlet(self).register(resource)
+ ProfileDisplaynameRestServlet(self).register(resource)
+ ProfileRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 7a4310ca18..d59007099b 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -144,8 +144,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
self.send_handler = FederationSenderHandler(hs, self)
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(FederationSenderReplicationHandler, self).on_rdata(
+ yield super(FederationSenderReplicationHandler, self).on_rdata(
stream_name, token, rows
)
self.send_handler.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 671fbbcb2a..8d484c1cd4 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -38,6 +38,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@@ -49,6 +50,35 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy")
+class PresenceStatusStubServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+
+ def __init__(self, hs):
+ super(PresenceStatusStubServlet, self).__init__(hs)
+ self.http_client = hs.get_simple_http_client()
+ self.auth = hs.get_auth()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+ headers = {
+ "Authorization": auth_headers,
+ }
+ result = yield self.http_client.get_json(
+ self.main_uri + request.uri,
+ headers=headers,
+ )
+ defer.returnValue((200, result))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id):
+ yield self.auth.get_user_by_req(request)
+ defer.returnValue((200, {}))
+
+
class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
@@ -135,6 +165,12 @@ class FrontendProxyServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
KeyUploadServlet(self).register(resource)
+
+ # If presence is disabled, use the stub servlet that does
+ # not allow sending presence
+ if not self.config.use_presence:
+ PresenceStatusStubServlet(self).register(resource)
+
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
@@ -153,7 +189,8 @@ class FrontendProxyServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
+ reactor=self.get_reactor()
)
logger.info("Synapse client reader now listening on port %d", port)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 9295a51d5b..a4fc7e91fa 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -148,8 +148,9 @@ class PusherReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+ yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks
@@ -162,11 +163,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
+ self.pusher_pool.on_new_notifications(
token, token,
)
elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
+ self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
)
except Exception:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index e201f18efd..27e1998660 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -114,7 +114,10 @@ class SynchrotronPresence(object):
logger.info("Presence process_id is %r", self.process_id)
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
- self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
+ if self.hs.config.use_presence:
+ self.hs.get_tcp_replication().send_user_sync(
+ user_id, is_syncing, last_sync_ms
+ )
def mark_as_coming_online(self, user_id):
"""A user has started syncing. Send a UserSync to the master, unless they
@@ -211,10 +214,13 @@ class SynchrotronPresence(object):
yield self.notify_from_replication(states, stream_id)
def get_currently_syncing_users(self):
- return [
- user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
- if count > 0
- ]
+ if self.hs.config.use_presence:
+ return [
+ user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
+ if count > 0
+ ]
+ else:
+ return set()
class SynchrotronTyping(object):
@@ -332,8 +338,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+ yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index cb78de8834..1388a42b59 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -169,8 +169,9 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
self.user_directory = hs.get_user_directory_handler()
+ @defer.inlineCallbacks
def on_rdata(self, stream_name, token, rows):
- super(UserDirectoryReplicationHandler, self).on_rdata(
+ yield super(UserDirectoryReplicationHandler, self).on_rdata(
stream_name, token, rows
)
if stream_name == "current_state_deltas":
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index cfc20dcccf..3f187adfc8 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
if log_file:
# TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler(
- log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
+ log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
+ encoding='utf8'
)
def sighup(signum, stack):
diff --git a/synapse/config/server.py b/synapse/config/server.py
index a41c48e69c..68a612e594 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -49,6 +49,9 @@ class ServerConfig(Config):
# "disable" federation
self.send_federation = config.get("send_federation", True)
+ # Whether to enable user presence.
+ self.use_presence = config.get("use_presence", True)
+
# Whether to update the user directory or not. This should be set to
# false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True)
@@ -250,6 +253,9 @@ class ServerConfig(Config):
# hard limit.
soft_file_limit: 0
+ # Set to false to disable presence tracking on this homeserver.
+ use_presence: true
+
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c20a32096a..e94400b8e2 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,9 @@ import logging
from canonicaljson import json
from twisted.internet import defer, reactor
+from twisted.internet.error import ConnectError
from twisted.internet.protocol import Factory
+from twisted.names.error import DomainError
from twisted.web.http import HTTPClient
from synapse.http.endpoint import matrix_federation_endpoint
@@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
- logger.exception("Error getting key for %r" % (server_name,))
+ logger.warn("Error getting key for %r: %s", server_name, e)
if e.status.startswith("4"):
# Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name)
+ except (ConnectError, DomainError) as e:
+ logger.warn("Error getting key for %r: %s", server_name, e)
except Exception as e:
- logger.exception(e)
+ logger.exception("Error getting key for %r", server_name)
raise IOError("Cannot get key for %r" % server_name)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f603c8a368..94d7423d01 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -58,6 +58,7 @@ class TransactionQueue(object):
"""
def __init__(self, hs):
+ self.hs = hs
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -308,6 +309,9 @@ class TransactionQueue(object):
Args:
states (list(UserPresenceState))
"""
+ if not self.hs.config.use_presence:
+ # No-op if presence is disabled.
+ return
# First we queue up the new presence by user ID, so multiple presence
    # updates in quick succession are correctly handled
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index b4fbe2c9d5..1054441ca5 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -106,7 +106,7 @@ class TransportLayerClient(object):
dest (str)
room_id (str)
event_tuples (list)
- limt (int)
+ limit (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 77969a4f38..7a993fd1cf 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -261,10 +261,10 @@ class BaseFederationServlet(object):
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
- logger.exception("authenticate_request failed")
+ logger.warn("authenticate_request failed: missing authentication")
raise
- except Exception:
- logger.exception("authenticate_request failed")
+ except Exception as e:
+ logger.warn("authenticate_request failed: %s", e)
raise
if origin:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f38b393e4a..0ebf0fd188 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -291,8 +291,9 @@ class FederationHandler(BaseHandler):
ev_ids, get_prev_content=False, check_redacted=False
)
+ room_version = yield self.store.get_room_version(pdu.room_id)
state_map = yield resolve_events_with_factory(
- state_groups, {pdu.event_id: pdu}, fetch
+ room_version, state_groups, {pdu.event_id: pdu}, fetch
)
state = (yield self.store.get_events(state_map.values())).values()
@@ -1828,7 +1829,10 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
+ room_version = yield self.store.get_room_version(event.room_id)
+
new_state = self.state_handler.resolve_events(
+ room_version,
[list(local_view.values()), list(remote_view.values())],
event
)
@@ -2386,8 +2390,7 @@ class FederationHandler(BaseHandler):
extra_users=extra_users
)
- logcontext.run_in_background(
- self.pusher_pool.on_new_notifications,
+ self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id,
)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 1fb17fd9a5..e009395207 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -372,6 +372,10 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_presence():
+ # If presence is disabled, return an empty list
+ if not self.hs.config.use_presence:
+ defer.returnValue([])
+
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4d006df63c..e484061cc0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -778,11 +778,8 @@ class EventCreationHandler(object):
event, context=context
)
- # this intentionally does not yield: we don't care about the result
- # and don't need to wait for it.
- run_in_background(
- self.pusher_pool.on_new_notifications,
- event_stream_id, max_stream_id
+ self.pusher_pool.on_new_notifications(
+ event_stream_id, max_stream_id,
)
def _notify():
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3671d24f60..ba3856674d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -395,6 +395,10 @@ class PresenceHandler(object):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
+ # If presence is disabled, no-op
+ if not self.hs.config.use_presence:
+ return
+
user_id = user.to_string()
bump_active_time_counter.inc()
@@ -424,6 +428,11 @@ class PresenceHandler(object):
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
+ # Override if it should affect the user's presence, if presence is
+ # disabled.
+ if not self.hs.config.use_presence:
+ affect_presence = False
+
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
@@ -469,13 +478,16 @@ class PresenceHandler(object):
Returns:
set(str): A set of user_id strings.
"""
- syncing_user_ids = {
- user_id for user_id, count in self.user_to_num_current_syncs.items()
- if count
- }
- for user_ids in self.external_process_to_current_syncs.values():
- syncing_user_ids.update(user_ids)
- return syncing_user_ids
+ if self.hs.config.use_presence:
+ syncing_user_ids = {
+ user_id for user_id, count in self.user_to_num_current_syncs.items()
+ if count
+ }
+ for user_ids in self.external_process_to_current_syncs.values():
+ syncing_user_ids.update(user_ids)
+ return syncing_user_ids
+ else:
+ return set()
@defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
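All of the presence changes in this handler follow a single guard pattern:
consult use_presence up front and fall back to an empty or no-op result when
presence is disabled. A condensed sketch of the pattern (a hypothetical
stand-in class, not Synapse's actual handler):

    class PresenceSketch(object):
        def __init__(self, use_presence):
            self.use_presence = use_presence
            self.user_to_num_current_syncs = {}

        def get_currently_syncing_users(self):
            # With presence disabled, report nobody as syncing, so no
            # presence traffic is generated on behalf of this process.
            if not self.use_presence:
                return set()
            return {
                user_id
                for user_id, count in self.user_to_num_current_syncs.items()
                if count
            }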
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 9af2e8f869..75b8b7ce6a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -32,12 +32,16 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
-class ProfileHandler(BaseHandler):
- PROFILE_UPDATE_MS = 60 * 1000
- PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+class BaseProfileHandler(BaseHandler):
+ """Handles fetching and updating user profile information.
+
+ BaseProfileHandler can be instantiated directly on workers and will
+ delegate to master when necessary. The master process should use the
+ subclass MasterProfileHandler
+ """
def __init__(self, hs):
- super(ProfileHandler, self).__init__(hs)
+ super(BaseProfileHandler, self).__init__(hs)
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -46,11 +50,6 @@ class ProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
- if hs.config.worker_app is None:
- self.clock.looping_call(
- self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
- )
-
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -282,6 +281,20 @@ class ProfileHandler(BaseHandler):
room_id, str(e.message)
)
+
+class MasterProfileHandler(BaseProfileHandler):
+ PROFILE_UPDATE_MS = 60 * 1000
+ PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+
+ def __init__(self, hs):
+ super(MasterProfileHandler, self).__init__(hs)
+
+ assert hs.config.worker_app is None
+
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ )
+
def _start_update_remote_profile_cache(self):
return run_as_background_process(
"Update remote profile", self._update_remote_profile_cache,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index cb905a3903..a6f3181f09 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util import logcontext
-from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler
@@ -116,16 +115,15 @@ class ReceiptsHandler(BaseHandler):
affected_room_ids = list(set([r["room_id"] for r in receipts]))
- with PreserveLoggingContext():
- self.notifier.on_new_event(
- "receipt_key", max_batch_id, rooms=affected_room_ids
- )
- # Note that the min here shouldn't be relied upon to be accurate.
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ self.notifier.on_new_event(
+ "receipt_key", max_batch_id, rooms=affected_room_ids
+ )
+ # Note that the min here shouldn't be relied upon to be accurate.
+ self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids,
+ )
- defer.returnValue(True)
+ defer.returnValue(True)
@logcontext.preserve_fn # caller should not yield on this
@defer.inlineCallbacks
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index fb94b5d7d4..f643619047 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -344,6 +344,7 @@ class RoomMemberHandler(object):
latest_event_ids = (
event_id for (event_id, _, _) in prev_events_and_hashes
)
+
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ac3edf0cc9..648debc8aa 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -185,6 +185,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
class SyncHandler(object):
def __init__(self, hs):
+ self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
@@ -860,7 +861,7 @@ class SyncHandler(object):
since_token is None and
sync_config.filter_collection.blocks_all_presence()
)
- if not block_all_presence_data:
+ if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_users
)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 37dda64587..d8413d6aa7 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -119,6 +119,8 @@ class UserDirectoryHandler(object):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
+ # FIXME(#3714): We should probably do this in the same worker as all
+ # the other changes.
yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None,
)
@@ -127,6 +129,8 @@ class UserDirectoryHandler(object):
def handle_user_deactivated(self, user_id):
"""Called when a user ID is deactivated
"""
+ # FIXME(#3714): We should probably do this in the same worker as all
+ # the other changes.
yield self.store.remove_from_user_dir(user_id)
yield self.store.remove_from_user_in_public_room(user_id)
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 588e280571..72c2654678 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+import threading
from prometheus_client.core import Counter, Histogram
@@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter(
# The set of all in flight requests, set[RequestMetrics]
_in_flight_requests = set()
+# Protects the _in_flight_requests set from concurrent access
+_in_flight_requests_lock = threading.Lock()
+
def _get_in_flight_counts():
"""Returns a count of all in flight requests by (method, server_name)
@@ -120,7 +124,8 @@ def _get_in_flight_counts():
"""
# Cast to a list to prevent it changing while the Prometheus
# thread is collecting metrics
- reqs = list(_in_flight_requests)
+ with _in_flight_requests_lock:
+ reqs = list(_in_flight_requests)
for rm in reqs:
rm.update_metrics()
@@ -154,10 +159,12 @@ class RequestMetrics(object):
# to the "in flight" metrics.
self._request_stats = self.start_context.get_resource_usage()
- _in_flight_requests.add(self)
+ with _in_flight_requests_lock:
+ _in_flight_requests.add(self)
def stop(self, time_sec, request):
- _in_flight_requests.discard(self)
+ with _in_flight_requests_lock:
+ _in_flight_requests.discard(self)
context = LoggingContext.current_context()
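The fix is the standard copy-under-lock idiom: every mutation of the set takes
the lock, and the collection thread takes the same lock just long enough to
snapshot the set before iterating. A standalone sketch:

    import threading

    _in_flight = set()
    _in_flight_lock = threading.Lock()

    def start_request(rm):
        with _in_flight_lock:
            _in_flight.add(rm)

    def stop_request(rm):
        with _in_flight_lock:
            _in_flight.discard(rm)

    def collect_in_flight():
        # Snapshot under the lock, then iterate the copy without holding it,
        # so the metrics thread never races a mutating request thread.
        with _in_flight_lock:
            reqs = list(_in_flight)
        for rm in reqs:
            pass  # update per-request metrics here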
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6dacb31037..2d5c23e673 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,8 +25,9 @@ from canonicaljson import encode_canonical_json, encode_pretty_printed_json, jso
+from six import PY3
from twisted.internet import defer
from twisted.python import failure
-from twisted.web import resource, server
+from twisted.web import resource
from twisted.web.server import NOT_DONE_YET
+from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo
import synapse.events
@@ -37,10 +38,13 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.request_metrics import requests_counter
from synapse.util.caches import intern_dict
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.metrics import Measure
+from synapse.util.logcontext import preserve_fn
+
+if PY3:
+ from io import BytesIO
+else:
+ from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__)
@@ -60,11 +64,10 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling.
- Also adds logging as per wrap_request_handler_with_logging.
+ Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
- where "self" must have a "clock" attribute (and "request" must be a
- SynapseRequest).
+ where "request" must be a SynapseRequest.
The handler must return a deferred. If the deferred succeeds we assume that
a response has been sent. If the deferred fails with a SynapseError we use
@@ -108,24 +111,23 @@ def wrap_json_request_handler(h):
pretty_print=_request_user_agent_is_curl(request),
)
- return wrap_request_handler_with_logging(wrapped_request_handler)
+ return wrap_async_request_handler(wrapped_request_handler)
def wrap_html_request_handler(h):
"""Wraps a request handler method with exception handling.
- Also adds logging as per wrap_request_handler_with_logging.
+ Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
- where "self" must have a "clock" attribute (and "request" must be a
- SynapseRequest).
+ where "request" must be a SynapseRequest.
"""
def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
return d
- return wrap_request_handler_with_logging(wrapped_request_handler)
+ return wrap_async_request_handler(wrapped_request_handler)
def _return_html_error(f, request):
@@ -170,46 +172,26 @@ def _return_html_error(f, request):
finish_request(request)
-def wrap_request_handler_with_logging(h):
- """Wraps a request handler to provide logging and metrics
+def wrap_async_request_handler(h):
+ """Wraps an async request handler so that it calls request.processing.
+
+ This helps ensure that work done by the request handler after the request is completed
+ is correctly recorded against the request metrics/logs.
The handler method must have a signature of "handle_foo(self, request)",
- where "self" must have a "clock" attribute (and "request" must be a
- SynapseRequest).
+ where "request" must be a SynapseRequest.
- As well as calling `request.processing` (which will log the response and
- duration for this request), the wrapped request handler will insert the
- request id into the logging context.
+ The handler may return a deferred, in which case the completion of the request isn't
+ logged until the deferred completes.
"""
@defer.inlineCallbacks
- def wrapped_request_handler(self, request):
- """
- Args:
- self:
- request (synapse.http.site.SynapseRequest):
- """
+ def wrapped_async_request_handler(self, request):
+ with request.processing():
+ yield h(self, request)
- request_id = request.get_request_id()
- with LoggingContext(request_id) as request_context:
- request_context.request = request_id
- with Measure(self.clock, "wrapped_request_handler"):
- # we start the request metrics timer here with an initial stab
- # at the servlet name. For most requests that name will be
- # JsonResource (or a subclass), and JsonResource._async_render
- # will update it once it picks a servlet.
- servlet_name = self.__class__.__name__
- with request.processing(servlet_name):
- with PreserveLoggingContext(request_context):
- d = defer.maybeDeferred(h, self, request)
-
- # record the arrival of the request *after*
- # dispatching to the handler, so that the handler
- # can update the servlet name in the request
- # metrics
- requests_counter.labels(request.method,
- request.request_metrics.name).inc()
- yield d
- return wrapped_request_handler
+ # we need to preserve_fn here, because the synchronous render method won't yield for
+ # us (obviously)
+ return preserve_fn(wrapped_async_request_handler)
class HttpServer(object):
@@ -272,7 +254,7 @@ class JsonResource(HttpServer, resource.Resource):
""" This gets called by twisted every time someone sends us a request.
"""
self._async_render(request)
- return server.NOT_DONE_YET
+ return NOT_DONE_YET
@wrap_json_request_handler
@defer.inlineCallbacks
@@ -413,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
return
if pretty_print:
- json_bytes = (encode_pretty_printed_json(json_object) + "\n"
- ).encode("utf-8")
+ json_bytes = encode_pretty_printed_json(json_object) + b"\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
@@ -450,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
if send_cors:
set_cors_headers(request)
- request.write(json_bytes)
- finish_request(request)
+ # todo: we can almost certainly avoid this copy and encode the json straight into
+ # the bytesIO, but it would involve faffing around with string->bytes wrappers.
+ bytes_io = BytesIO(json_bytes)
+
+ producer = NoRangeStaticProducer(request, bytes_io)
+ producer.start()
return NOT_DONE_YET
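The net effect of wrap_async_request_handler is that completion logging now
hangs off a context manager on the request object (the real implementation is
in synapse/http/site.py below). A rough standalone sketch of that contract,
using a hypothetical class:

    import contextlib

    class RequestSketch(object):
        def __init__(self):
            self._is_processing = False

        @contextlib.contextmanager
        def processing(self):
            if self._is_processing:
                raise RuntimeError("Request is already processing")
            self._is_processing = True
            try:
                yield
            finally:
                self._is_processing = False
                # Completion is logged and metrics are updated here, i.e.
                # only once the handler has finished all of its work.

    request = RequestSketch()
    with request.processing():
        pass  # the asynchronous handler's work happens here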
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 69f7085291..a1e4b88e6d 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
+ name (bytes/unicode): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
try:
return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
+ name (bytes/unicode): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
try:
return {
- "true": True,
- "false": False,
+ b"true": True,
+ b"false": False,
}[args[name][0]]
except Exception:
message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string"):
- """Parse a string parameter from the request query string.
+ allowed_values=None, param_type="string", encoding='ascii'):
+ """
+ Parse a string parameter from the request query string.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding; otherwise it will be returned
+    as bytes.
Args:
request: the twisted HTTP request.
- name (str): the name of the query parameter.
- default (str|None): value to use if the parameter is absent, defaults
- to None.
+ name (bytes/unicode): the name of the query parameter.
+ default (bytes/unicode|None): value to use if the parameter is absent,
+ defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
- allowed_values (list[str]): List of allowed values for the string,
- or None if any value is allowed, defaults to None
+ allowed_values (list[bytes/unicode]): List of allowed values for the
+ string, or None if any value is allowed, defaults to None. Must be
+ the same type as name, if given.
+        encoding: The encoding to decode the string content with.
Returns:
- str|None: A string value or the default.
+ bytes/unicode|None: A string value or the default. Unicode if encoding
+ was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
is not one of those allowed values.
"""
return parse_string_from_args(
- request.args, name, default, required, allowed_values, param_type,
+ request.args, name, default, required, allowed_values, param_type, encoding
)
def parse_string_from_args(args, name, default=None, required=False,
- allowed_values=None, param_type="string"):
+ allowed_values=None, param_type="string", encoding='ascii'):
+
+ if not isinstance(name, bytes):
+ name = name.encode('ascii')
+
if name in args:
value = args[name][0]
+
+ if encoding:
+ value = value.decode(encoding)
+
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else:
+
+ if encoding and isinstance(default, bytes):
+ return default.decode(encoding)
+
return default
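To see the new encoding behaviour end to end, here is a self-contained sketch
that mirrors the branch logic above (it deliberately reimplements it rather
than importing synapse):

    def parse_string_from_args(args, name, default=None, encoding='ascii'):
        # Query-string keys are bytes under Python 3, so a unicode name is
        # encoded before the lookup.
        if not isinstance(name, bytes):
            name = name.encode('ascii')
        if name in args:
            value = args[name][0]
            if encoding:
                value = value.decode(encoding)
            return value
        if encoding and isinstance(default, bytes):
            return default.decode(encoding)
        return default

    args = {b"room_id": [b"!abc:example.com"]}
    print(parse_string_from_args(args, "room_id"))       # !abc:example.com
    print(parse_string_from_args(args, "limit", b"10"))  # 10 (decoded default)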
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5fd30a4c2c..88ed3714f9 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
import logging
import time
@@ -19,8 +18,8 @@ import time
from twisted.web.server import Request, Site
from synapse.http import redact_uri
-from synapse.http.request_metrics import RequestMetrics
-from synapse.util.logcontext import ContextResourceUsage, LoggingContext
+from synapse.http.request_metrics import RequestMetrics, requests_counter
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -34,25 +33,43 @@ class SynapseRequest(Request):
It extends twisted's twisted.web.server.Request, and adds:
* Unique request ID
+ * A log context associated with the request
* Redaction of access_token query-params in __repr__
* Logging at start and end
* Metrics to record CPU, wallclock and DB time by endpoint.
- It provides a method `processing` which should be called by the Resource
- which is handling the request, and returns a context manager.
+ It also provides a method `processing`, which returns a context manager. If this
+ method is called, the request won't be logged until the context manager is closed;
+ this is useful for asynchronous request handlers which may go on processing the
+ request even after the client has disconnected.
+ Attributes:
+ logcontext (LoggingContext): the log context for this request
"""
def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = site
- self._channel = channel
+ self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
+ # we can't yet create the logcontext, as we don't know the method.
+ self.logcontext = None
+
global _next_request_seq
self.request_seq = _next_request_seq
_next_request_seq += 1
+ # whether an asynchronous request handler has called processing()
+ self._is_processing = False
+
+ # the time when the asynchronous request handler completed its processing
+ self._processing_finished_time = None
+
+ # what time we finished sending the response to the client (or the connection
+ # dropped)
+ self.finish_time = None
+
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
@@ -74,11 +91,116 @@ class SynapseRequest(Request):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
def render(self, resrc):
+ # this is called once a Resource has been found to serve the request; in our
+ # case the Resource in question will normally be a JsonResource.
+
+ # create a LogContext for this request
+ request_id = self.get_request_id()
+ logcontext = self.logcontext = LoggingContext(request_id)
+ logcontext.request = request_id
+
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
- return Request.render(self, resrc)
+
+ with PreserveLoggingContext(self.logcontext):
+ # we start the request metrics timer here with an initial stab
+ # at the servlet name. For most requests that name will be
+ # JsonResource (or a subclass), and JsonResource._async_render
+ # will update it once it picks a servlet.
+ servlet_name = resrc.__class__.__name__
+ self._started_processing(servlet_name)
+
+ Request.render(self, resrc)
+
+ # record the arrival of the request *after*
+ # dispatching to the handler, so that the handler
+ # can update the servlet name in the request
+ # metrics
+ requests_counter.labels(self.method,
+ self.request_metrics.name).inc()
+
+ @contextlib.contextmanager
+ def processing(self):
+ """Record the fact that we are processing this request.
+
+ Returns a context manager; the correct way to use this is:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ with request.processing("FooServlet"):
+ yield really_handle_the_request()
+
+ Once the context manager is closed, the completion of the request will be logged,
+ and the various metrics will be updated.
+ """
+ if self._is_processing:
+ raise RuntimeError("Request is already processing")
+ self._is_processing = True
+
+ try:
+ yield
+ except Exception:
+ # this should already have been caught, and sent back to the client as a 500.
+ logger.exception("Asynchronous messge handler raised an uncaught exception")
+ finally:
+ # the request handler has finished its work and either sent the whole response
+ # back, or handed over responsibility to a Producer.
+
+ self._processing_finished_time = time.time()
+ self._is_processing = False
+
+ # if we've already sent the response, log it now; otherwise, we wait for the
+ # response to be sent.
+ if self.finish_time is not None:
+ self._finished_processing()
+
+ def finish(self):
+ """Called when all response data has been written to this Request.
+
+ Overrides twisted.web.server.Request.finish to record the finish time and do
+ logging.
+ """
+ self.finish_time = time.time()
+ Request.finish(self)
+ if not self._is_processing:
+ with PreserveLoggingContext(self.logcontext):
+ self._finished_processing()
+
+ def connectionLost(self, reason):
+ """Called when the client connection is closed before the response is written.
+
+ Overrides twisted.web.server.Request.connectionLost to record the finish time and
+ do logging.
+ """
+ self.finish_time = time.time()
+ Request.connectionLost(self, reason)
+
+ # we only get here if the connection to the client drops before we send
+ # the response.
+ #
+ # It's useful to log it here so that we can get an idea of when
+ # the client disconnects.
+ with PreserveLoggingContext(self.logcontext):
+ logger.warn(
+ "Error processing request %r: %s %s", self, reason.type, reason.value,
+ )
+
+ if not self._is_processing:
+ self._finished_processing()
def _started_processing(self, servlet_name):
+ """Record the fact that we are processing this request.
+
+ This will log the request's arrival. Once the request completes,
+ be sure to call finished_processing.
+
+ Args:
+ servlet_name (str): the name of the servlet which will be
+ processing this request. This is used in the metrics.
+
+ It is possible to update this afterwards by updating
+ self.request_metrics.name.
+ """
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
@@ -94,18 +216,32 @@ class SynapseRequest(Request):
)
def _finished_processing(self):
- try:
- context = LoggingContext.current_context()
- usage = context.get_resource_usage()
- except Exception:
- usage = ContextResourceUsage()
+ """Log the completion of this request and update the metrics
+ """
+
+ if self.logcontext is None:
+ # this can happen if the connection closed before we read the
+ # headers (so render was never called). In that case we'll already
+ # have logged a warning, so just bail out.
+ return
+
+ usage = self.logcontext.get_resource_usage()
+
+ if self._processing_finished_time is None:
+ # we completed the request without anything calling processing()
+ self._processing_finished_time = time.time()
- end_time = time.time()
+ # the time between receiving the request and the request handler finishing
+ processing_time = self._processing_finished_time - self.start_time
+
+ # the time between the request handler finishing and the response being sent
+ # to the client (nb may be negative)
+ response_send_time = self.finish_time - self._processing_finished_time
# need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity
- if authenticated_entity is not None:
+ if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header.
@@ -116,22 +252,31 @@ class SynapseRequest(Request):
user_agent = self.get_user_agent()
if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace")
+ else:
+ user_agent = "-"
+
+ code = str(self.code)
+ if not self.finished:
+ # we didn't send the full response before we gave up (presumably because
+ # the connection dropped)
+ code += "!"
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
+ " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(),
self.site.site_tag,
authenticated_entity,
- end_time - self.start_time,
+ processing_time,
+ response_send_time,
usage.ru_utime,
usage.ru_stime,
usage.db_sched_duration_sec,
usage.db_txn_duration_sec,
int(usage.db_txn_count),
self.sentLength,
- self.code,
+ code,
self.method,
self.get_redacted_uri(),
self.clientproto,
@@ -140,38 +285,10 @@ class SynapseRequest(Request):
)
try:
- self.request_metrics.stop(end_time, self)
+ self.request_metrics.stop(self.finish_time, self)
except Exception as e:
logger.warn("Failed to stop metrics: %r", e)
- @contextlib.contextmanager
- def processing(self, servlet_name):
- """Record the fact that we are processing this request.
-
- Returns a context manager; the correct way to use this is:
-
- @defer.inlineCallbacks
- def handle_request(request):
- with request.processing("FooServlet"):
- yield really_handle_the_request()
-
- This will log the request's arrival. Once the context manager is
- closed, the completion of the request will be logged, and the various
- metrics will be updated.
-
- Args:
- servlet_name (str): the name of the servlet which will be
- processing this request. This is used in the metrics.
-
- It is possible to update this afterwards by updating
- self.request_metrics.servlet_name.
- """
- # TODO: we should probably just move this into render() and finish(),
- # to save having to call a separate method.
- self._started_processing(servlet_name)
- yield
- self._finished_processing()
-
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
@@ -217,7 +334,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
- self.server_version_string = server_version_string
+ self.server_version_string = server_version_string.encode('ascii')
def log(self, request):
pass
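
A hedged sketch of the new request lifecycle (the handler names below are hypothetical; `JsonResource` implements the real version):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def _async_render(request):
    # By this point render() has created request.logcontext and called
    # _started_processing(), so metrics and logging are already running.
    with request.processing():
        # While inside the context manager, _is_processing is True, so
        # neither finish() nor connectionLost() writes the access-log line.
        yield handle_the_request(request)  # hypothetical handler
    # On exit, _processing_finished_time is recorded. The access-log line is
    # written immediately if the response has already been sent, otherwise
    # when finish() (or connectionLost()) fires later.
```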
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index ce678d5f75..167167be0a 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import threading
+
import six
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
@@ -78,6 +80,9 @@ _background_process_counts = dict() # type: dict[str, int]
# of process descriptions that no longer have any active processes.
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
+# A lock that covers the above dicts
+_bg_metrics_lock = threading.Lock()
+
class _Collector(object):
"""A custom metrics collector for the background process metrics.
@@ -92,7 +97,11 @@ class _Collector(object):
labels=["name"],
)
- for desc, processes in six.iteritems(_background_processes):
+ # We copy the dict so that it doesn't change from underneath us
+ with _bg_metrics_lock:
+ _background_processes_copy = dict(_background_processes)
+
+ for desc, processes in six.iteritems(_background_processes_copy):
background_process_in_flight_count.add_metric(
(desc,), len(processes),
)
@@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs):
"""
@defer.inlineCallbacks
def run():
- count = _background_process_counts.get(desc, 0)
- _background_process_counts[desc] = count + 1
+ with _bg_metrics_lock:
+ count = _background_process_counts.get(desc, 0)
+ _background_process_counts[desc] = count + 1
+
_background_process_start_count.labels(desc).inc()
with LoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count)
proc = _BackgroundProcess(desc, context)
- _background_processes.setdefault(desc, set()).add(proc)
+
+ with _bg_metrics_lock:
+ _background_processes.setdefault(desc, set()).add(proc)
+
try:
yield func(*args, **kwargs)
finally:
proc.update_metrics()
- _background_processes[desc].remove(proc)
+
+ with _bg_metrics_lock:
+ _background_processes[desc].remove(proc)
with PreserveLoggingContext():
return run()
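
A minimal usage sketch (the process name and body are illustrative). Each invocation now updates the shared dicts only while holding `_bg_metrics_lock`, so the prometheus collector thread can safely snapshot them:

```python
from twisted.internet import defer

from synapse.metrics.background_process_metrics import run_as_background_process

@defer.inlineCallbacks
def _prune_old_entries():
    # the actual work runs in its own LoggingContext named after the process
    yield defer.succeed(None)

run_as_background_process("prune_old_entries", _prune_old_entries)
```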
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 36bb5bbc65..9f7d5ef217 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -18,6 +18,7 @@ import logging
from twisted.internet import defer
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@@ -122,8 +123,14 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'],
)
- @defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
+ run_as_background_process(
+ "on_new_notifications",
+ self._on_new_notifications, min_stream_id, max_stream_id,
+ )
+
+ @defer.inlineCallbacks
+ def _on_new_notifications(self, min_stream_id, max_stream_id):
try:
users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
@@ -147,8 +154,14 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- @defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ run_as_background_process(
+ "on_new_receipts",
+ self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
+ )
+
+ @defer.inlineCallbacks
+ def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
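
The effect on call sites, sketched rather than quoted from the codebase: the public methods no longer return a Deferred for callers to wait on, so any `yield` on them should be dropped:

```python
# before: the caller owned the work
#     yield pusher_pool.on_new_notifications(min_stream_id, max_stream_id)
#
# after: the pool schedules the work as a tracked background process and
# the call returns immediately
pusher_pool.on_new_notifications(min_stream_id, max_stream_id)
```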
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2ddd18f73b..64a79da162 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
edu_content = content["content"]
logger.info(
- "Got %r edu from $s",
+ "Got %r edu from %s",
edu_type, origin,
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 970e94313e..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -107,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overridden in subclasses to handle more.
"""
logger.info("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
+ return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
@@ -115,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overridden in subclasses to handle more.
"""
- self.store.process_replication_rows(stream_name, token, [])
+ return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f3908df642..327556f6a1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -59,6 +59,12 @@ class Command(object):
"""
return self.data
+ def get_logcontext_id(self):
+ """Get a suitable string for the logcontext when processing this command"""
+
+ # by default, we just use the command name.
+ return self.NAME
+
class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name.
@@ -116,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row),
))
+ def get_logcontext_id(self):
+ return "RDATA-" + self.stream_name
+
class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without
@@ -190,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self):
return " ".join((self.stream_name, str(self.token),))
+ def get_logcontext_id(self):
+ return "REPLICATE-" + self.stream_name
+
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index dec5ac0913..74e892c104 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -63,6 +63,8 @@ from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from .commands import (
@@ -222,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
try:
- getattr(self, "on_%s" % (cmd_name,))(cmd)
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(),
+ getattr(self, "on_%s" % (cmd_name,)),
+ cmd,
+ )
except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@@ -387,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data
def on_USER_SYNC(self, cmd):
- self.streamer.on_user_sync(
+ return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
)
@@ -397,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in iterkeys(self.streamer.streams_by_name):
- self.subscribe_to_stream(stream, token)
+ deferreds = [
+ run_in_background(
+ self.subscribe_to_stream,
+ stream, token,
+ )
+ for stream in iterkeys(self.streamer.streams_by_name)
+ ]
+
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
else:
- self.subscribe_to_stream(stream_name, token)
+ return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
+ return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
- self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ return self.streamer.on_remove_pusher(
+ cmd.app_id, cmd.push_key, cmd.user_id,
+ )
def on_INVALIDATE_CACHE(self, cmd):
- self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
+ return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen,
)
@@ -542,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
-
- self.handler.on_rdata(stream_name, cmd.token, rows)
+ return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
- self.handler.on_position(cmd.stream_name, cmd.token)
+ return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 511e96ab00..48c17f1b6d 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -53,7 +53,7 @@ class HttpTransactionCache(object):
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
- return request.path + "/" + token
+ return request.path.decode('utf8') + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index a14f0c807e..b5a6d6aebf 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -84,7 +84,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except Exception:
raise SynapseError(400, "Unable to parse state")
- yield self.presence_handler.set_state(user, state)
+ if self.hs.config.use_presence:
+ yield self.presence_handler.set_state(user, state)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index fcc1091760..976d98387d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -531,7 +531,7 @@ class RoomEventServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
event = yield self.event_handler.get_event(requester.user, room_id, event_id)
time_now = self.clock.time_msec()
diff --git a/synapse/rest/client/v1_only/register.py b/synapse/rest/client/v1_only/register.py
index 3439c3c6d4..5e99cffbcb 100644
--- a/synapse/rest/client/v1_only/register.py
+++ b/synapse/rest/client/v1_only/register.py
@@ -129,12 +129,9 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
- is_using_shared_secret = login_type == LoginType.SHARED_SECRET
-
can_register = (
self.enable_registration
or is_application_server
- or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
@@ -144,7 +141,6 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service,
- LoginType.SHARED_SECRET: self._do_shared_secret,
}
session_info = self._get_session_info(request, session)
@@ -325,56 +321,6 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
})
- @defer.inlineCallbacks
- def _do_shared_secret(self, request, register_json, session):
- assert_params_in_dict(register_json, ["mac", "user", "password"])
-
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
-
- user = register_json["user"].encode("utf-8")
- password = register_json["password"].encode("utf-8")
- admin = register_json.get("admin", None)
-
- # Its important to check as we use null bytes as HMAC field separators
- if b"\x00" in user:
- raise SynapseError(400, "Invalid user")
- if b"\x00" in password:
- raise SynapseError(400, "Invalid password")
-
- # str() because otherwise hmac complains that 'unicode' does not
- # have the buffer interface
- got_mac = str(register_json["mac"])
-
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- digestmod=sha1,
- )
- want_mac.update(user)
- want_mac.update(b"\x00")
- want_mac.update(password)
- want_mac.update(b"\x00")
- want_mac.update(b"admin" if admin else b"notadmin")
- want_mac = want_mac.hexdigest()
-
- if compare_digest(want_mac, got_mac):
- handler = self.handlers.registration_handler
- user_id, token = yield handler.register(
- localpart=user.lower(),
- password=password,
- admin=bool(admin),
- )
- self._remove_session(session)
- defer.returnValue({
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- })
- else:
- raise SynapseError(
- 403, "HMAC incorrect",
- )
-
class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 147ff7d79b..7362e1858d 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -140,7 +140,7 @@ class ConsentResource(Resource):
version = parse_string(request, "v",
default=self._default_consent_version)
username = parse_string(request, "u", required=True)
- userhmac = parse_string(request, "h", required=True)
+ userhmac = parse_string(request, "h", required=True, encoding=None)
self._check_hash(username, userhmac)
@@ -175,7 +175,7 @@ class ConsentResource(Resource):
"""
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
- userhmac = parse_string(request, "h", required=True)
+ userhmac = parse_string(request, "h", required=True, encoding=None)
self._check_hash(username, userhmac)
@@ -210,9 +210,18 @@ class ConsentResource(Resource):
finish_request(request)
def _check_hash(self, userid, userhmac):
+ """
+ Args:
+ userid (unicode):
+ userhmac (bytes):
+
+ Raises:
+ SynapseError if the hash doesn't match
+
+ """
want_mac = hmac.new(
key=self._hmac_secret,
- msg=userid,
+ msg=userid.encode('utf-8'),
digestmod=sha256,
).hexdigest()
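
For reference, a sketch of how the `h` parameter is expected to be produced, mirroring `_check_hash` (the secret below is a placeholder for the homeserver's `form_secret`):

```python
import hmac
from hashlib import sha256

form_secret = b"<form_secret from the homeserver config>"  # placeholder
username = u"@alice:example.com"

userhmac = hmac.new(
    key=form_secret,
    msg=username.encode('utf-8'),  # matches the explicit encode added above
    digestmod=sha256,
).hexdigest()

# userhmac is then passed as the `h` query parameter, which the resource now
# reads with encoding=None (i.e. as raw bytes) before checking it.
```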
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 9b22d204a6..c1240e1963 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -55,7 +55,7 @@ class UploadResource(Resource):
requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader("Content-Length")
+ content_length = request.getHeader(b"Content-Length")
+ if content_length is not None:
+ content_length = content_length.decode('ascii')
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
@@ -66,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
- upload_name = parse_string(request, "filename")
+ upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
try:
- upload_name = upload_name.decode('UTF-8')
+ upload_name = upload_name.decode('utf8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@@ -78,8 +78,8 @@ class UploadResource(Resource):
headers = request.requestHeaders
- if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders(b"Content-Type")[0]
+ if headers.hasHeader(b"Content-Type"):
+ media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
diff --git a/synapse/secrets.py b/synapse/secrets.py
index f05e9ea535..f6280f951c 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -38,4 +38,4 @@ else:
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
- return binascii.hexlify(self.token_bytes(nbytes))
+ return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
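
A quick check of the resulting behaviour (illustrative): `token_hex` now yields text rather than bytes, matching the stdlib `secrets` module that backs it on Python 3.

```python
from synapse.secrets import Secrets

token = Secrets().token_hex(16)
assert not isinstance(token, bytes)  # holds on both Python 2 and 3 now
```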
diff --git a/synapse/server.py b/synapse/server.py
index 26228d8c72..a6fbc6ec0c 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -56,7 +56,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
from synapse.handlers.presence import PresenceHandler
-from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.room import RoomContextHandler, RoomCreationHandler
@@ -308,7 +308,10 @@ class HomeServer(object):
return InitialSyncHandler(self)
def build_profile_handler(self):
- return ProfileHandler(self)
+ if self.config.worker_app:
+ return BaseProfileHandler(self)
+ else:
+ return MasterProfileHandler(self)
def build_event_creation_handler(self):
return EventCreationHandler(self)
diff --git a/synapse/state.py b/synapse/state/__init__.py
index 8b92d4057a..b34970e4d1 100644
--- a/synapse/state.py
+++ b/synapse/state/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,23 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import hashlib
import logging
from collections import namedtuple
-from six import iteritems, iterkeys, itervalues
+from six import iteritems, itervalues
from frozendict import frozendict
from twisted.internet import defer
-from synapse import event_auth
-from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError
+from synapse.api.constants import EventTypes, RoomVersions
from synapse.events.snapshot import EventContext
+from synapse.state import v1
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
@@ -40,7 +38,7 @@ logger = logging.getLogger(__name__)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
+SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache")
EVICTION_TIMEOUT_SECONDS = 60 * 60
@@ -264,6 +262,7 @@ class StateHandler(object):
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
+
entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events],
)
@@ -338,8 +337,11 @@ class StateHandler(object):
event, resolves conflicts between them and returns them.
Args:
- room_id (str):
- event_ids (list[str]):
+ room_id (str)
+ event_ids (list[str])
+ explicit_room_version (str|None): If set, uses the given room
+ version to choose the resolution algorithm. If None, the room
+ version is looked up in the database.
Returns:
Deferred[_StateCacheEntry]: resolved state
@@ -353,7 +355,12 @@ class StateHandler(object):
room_id, event_ids
)
- if len(state_groups_ids) == 1:
+ if len(state_groups_ids) == 0:
+ defer.returnValue(_StateCacheEntry(
+ state={},
+ state_group=None,
+ ))
+ elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -365,8 +372,11 @@ class StateHandler(object):
delta_ids=delta_ids,
))
+ room_version = yield self.store.get_room_version(room_id)
+
result = yield self._state_resolution_handler.resolve_state_groups(
- room_id, state_groups_ids, None, self._state_map_factory,
+ room_id, room_version, state_groups_ids, None,
+ self._state_map_factory,
)
defer.returnValue(result)
@@ -375,7 +385,7 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False,
)
- def resolve_events(self, state_sets, event):
+ def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -391,7 +401,9 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events_with_state_map(state_set_ids, state_map)
+ new_state = resolve_events_with_state_map(
+ room_version, state_set_ids, state_map,
+ )
new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
@@ -430,7 +442,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
- self, room_id, state_groups_ids, event_map, state_map_factory,
+ self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
):
"""Resolves conflicts between a set of state groups
@@ -439,6 +451,7 @@ class StateResolutionHandler(object):
Args:
room_id (str): room we are resolving for (used for logging)
+ room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
@@ -492,6 +505,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
+ room_version,
list(itervalues(state_groups_ids)),
event_map=event_map,
state_map_factory=state_map_factory,
@@ -575,16 +589,10 @@ def _make_state_cache_entry(
)
-def _ordered_events(events):
- def key_func(e):
- return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
-
- return sorted(events, key=key_func)
-
-
-def resolve_events_with_state_map(state_sets, state_map):
+def resolve_events_with_state_map(room_version, state_sets, state_map):
"""
Args:
+ room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
@@ -594,75 +602,23 @@ def resolve_events_with_state_map(state_sets, state_map):
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
- if len(state_sets) == 1:
- return state_sets[0]
-
- unconflicted_state, conflicted_state = _seperate(
- state_sets,
- )
-
- auth_events = _create_auth_events_from_maps(
- unconflicted_state, conflicted_state, state_map
- )
-
- return _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
- )
-
-
-def _seperate(state_sets):
- """Takes the state_sets and figures out which keys are conflicted and
- which aren't. i.e., which have multiple different event_ids associated
- with them in different state sets.
-
- Args:
- state_sets(iterable[dict[(str, str), str]]):
- List of dicts of (type, state_key) -> event_id, which are the
- different state groups to resolve.
-
- Returns:
- (dict[(str, str), str], dict[(str, str), set[str]]):
- A tuple of (unconflicted_state, conflicted_state), where:
-
- unconflicted_state is a dict mapping (type, state_key)->event_id
- for unconflicted state keys.
-
- conflicted_state is a dict mapping (type, state_key) to a set of
- event ids for conflicted state keys.
- """
- state_set_iterator = iter(state_sets)
- unconflicted_state = dict(next(state_set_iterator))
- conflicted_state = {}
-
- for state_set in state_set_iterator:
- for key, value in iteritems(state_set):
- # Check if there is an unconflicted entry for the state key.
- unconflicted_value = unconflicted_state.get(key)
- if unconflicted_value is None:
- # There isn't an unconflicted entry so check if there is a
- # conflicted entry.
- ls = conflicted_state.get(key)
- if ls is None:
- # There wasn't a conflicted entry so haven't seen this key before.
- # Therefore it isn't conflicted yet.
- unconflicted_state[key] = value
- else:
- # This key is already conflicted, add our value to the conflict set.
- ls.add(value)
- elif unconflicted_value != value:
- # If the unconflicted value is not the same as our value then we
- # have a new conflict. So move the key from the unconflicted_state
- # to the conflicted state.
- conflicted_state[key] = {value, unconflicted_value}
- unconflicted_state.pop(key, None)
-
- return unconflicted_state, conflicted_state
+ if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
+ return v1.resolve_events_with_state_map(
+ state_sets, state_map,
+ )
+ else:
+ # This should only happen if we added a version but forgot to add it to
+ # the list above.
+ raise Exception(
+ "No state resolution algorithm defined for version %r" % (room_version,)
+ )
-@defer.inlineCallbacks
-def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""
Args:
+ room_version(str): Version of the room
+
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
@@ -682,185 +638,13 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
- if len(state_sets) == 1:
- defer.returnValue(state_sets[0])
-
- unconflicted_state, conflicted_state = _seperate(
- state_sets,
- )
-
- needed_events = set(
- event_id
- for event_ids in itervalues(conflicted_state)
- for event_id in event_ids
- )
- if event_map is not None:
- needed_events -= set(iterkeys(event_map))
-
- logger.info("Asking for %d conflicted events", len(needed_events))
-
- # dict[str, FrozenEvent]: a map from state event id to event. Only includes
- # the state events which are in conflict (and those in event_map)
- state_map = yield state_map_factory(needed_events)
- if event_map is not None:
- state_map.update(event_map)
-
- # get the ids of the auth events which allow us to authenticate the
- # conflicted state, picking only from the unconflicting state.
- #
- # dict[(str, str), str]: a map from state key to event id
- auth_events = _create_auth_events_from_maps(
- unconflicted_state, conflicted_state, state_map
- )
-
- new_needed_events = set(itervalues(auth_events))
- new_needed_events -= needed_events
- if event_map is not None:
- new_needed_events -= set(iterkeys(event_map))
-
- logger.info("Asking for %d auth events", len(new_needed_events))
-
- state_map_new = yield state_map_factory(new_needed_events)
- state_map.update(state_map_new)
-
- defer.returnValue(_resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
- ))
-
-
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
- auth_events = {}
- for event_ids in itervalues(conflicted_state):
- for event_id in event_ids:
- if event_id in state_map:
- keys = event_auth.auth_types_for_event(state_map[event_id])
- for key in keys:
- if key not in auth_events:
- event_id = unconflicted_state.get(key, None)
- if event_id:
- auth_events[key] = event_id
- return auth_events
-
-
-def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
- state_map):
- conflicted_state = {}
- for key, event_ids in iteritems(conflicted_state_ids):
- events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
- if len(events) > 1:
- conflicted_state[key] = events
- elif len(events) == 1:
- unconflicted_state_ids[key] = events[0].event_id
-
- auth_events = {
- key: state_map[ev_id]
- for key, ev_id in iteritems(auth_event_ids)
- if ev_id in state_map
- }
-
- try:
- resolved_state = _resolve_state_events(
- conflicted_state, auth_events
+ if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
+ return v1.resolve_events_with_factory(
+ state_sets, event_map, state_map_factory,
+ )
+ else:
+ # This should only happen if we added a version but forgot to add it to
+ # the list above.
+ raise Exception(
+ "No state resolution algorithm defined for version %r" % (room_version,)
)
- except Exception:
- logger.exception("Failed to resolve state")
- raise
-
- new_state = unconflicted_state_ids
- for key, event in iteritems(resolved_state):
- new_state[key] = event.event_id
-
- return new_state
-
-
-def _resolve_state_events(conflicted_state, auth_events):
- """ This is where we actually decide which of the conflicted state to
- use.
-
- We resolve conflicts in the following order:
- 1. power levels
- 2. join rules
- 3. memberships
- 4. other events.
- """
- resolved_state = {}
- if POWER_KEY in conflicted_state:
- events = conflicted_state[POWER_KEY]
- logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[POWER_KEY] = _resolve_auth_events(
- events, auth_events)
-
- auth_events.update(resolved_state)
-
- for key, events in iteritems(conflicted_state):
- if key[0] == EventTypes.JoinRules:
- logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = _resolve_auth_events(
- events,
- auth_events
- )
-
- auth_events.update(resolved_state)
-
- for key, events in iteritems(conflicted_state):
- if key[0] == EventTypes.Member:
- logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = _resolve_auth_events(
- events,
- auth_events
- )
-
- auth_events.update(resolved_state)
-
- for key, events in iteritems(conflicted_state):
- if key not in resolved_state:
- logger.debug("Resolving conflicted state %r:%r", key, events)
- resolved_state[key] = _resolve_normal_events(
- events, auth_events
- )
-
- return resolved_state
-
-
-def _resolve_auth_events(events, auth_events):
- reverse = [i for i in reversed(_ordered_events(events))]
-
- auth_keys = set(
- key
- for event in events
- for key in event_auth.auth_types_for_event(event)
- )
-
- new_auth_events = {}
- for key in auth_keys:
- auth_event = auth_events.get(key, None)
- if auth_event:
- new_auth_events[key] = auth_event
-
- auth_events = new_auth_events
-
- prev_event = reverse[0]
- for event in reverse[1:]:
- auth_events[(prev_event.type, prev_event.state_key)] = prev_event
- try:
- # The signatures have already been checked at this point
- event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
- prev_event = event
- except AuthError:
- return prev_event
-
- return event
-
-
-def _resolve_normal_events(events, auth_events):
- for event in _ordered_events(events):
- try:
- # The signatures have already been checked at this point
- event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
- return event
- except AuthError:
- pass
-
- # Use the last event (the one with the least depth) if they all fail
- # the auth check.
- return event
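
The dispatch introduced above could equivalently be written as a table keyed on room version; a hedged sketch of the same pattern (not the actual implementation):

```python
from synapse.api.constants import RoomVersions
from synapse.state import v1

# Map each known room version to the module implementing its state
# resolution algorithm; both current versions use the v1 algorithm.
_STATE_RESOLVERS = {
    RoomVersions.V1: v1,
    RoomVersions.VDH_TEST: v1,
}

def resolver_for(room_version):
    try:
        return _STATE_RESOLVERS[room_version]
    except KeyError:
        # a version was added without being wired up here
        raise Exception(
            "No state resolution algorithm defined for version %r"
            % (room_version,)
        )
```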
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
new file mode 100644
index 0000000000..3a1f7054a1
--- /dev/null
+++ b/synapse/state/v1.py
@@ -0,0 +1,321 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import logging
+
+from six import iteritems, iterkeys, itervalues
+
+from twisted.internet import defer
+
+from synapse import event_auth
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
+
+logger = logging.getLogger(__name__)
+
+
+POWER_KEY = (EventTypes.PowerLevels, "")
+
+
+def resolve_events_with_state_map(state_sets, state_map):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+ state_map(dict): a dict from event_id to event, for all events in
+ state_sets.
+
+ Returns
+ dict[(str, str), str]:
+ a map from (type, state_key) to event_id.
+ """
+ if len(state_sets) == 1:
+ return state_sets[0]
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
+ auth_events = _create_auth_events_from_maps(
+ unconflicted_state, conflicted_state, state_map
+ )
+
+ return _resolve_with_state(
+ unconflicted_state, conflicted_state, auth_events, state_map
+ )
+
+
+@defer.inlineCallbacks
+def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+ """
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point for finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
+ state_map_factory(func): will be called
+ with a list of event_ids that are needed, and should return with
+ a Deferred of dict of event_id to event.
+
+ Returns
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
+ """
+ if len(state_sets) == 1:
+ defer.returnValue(state_sets[0])
+
+ unconflicted_state, conflicted_state = _seperate(
+ state_sets,
+ )
+
+ needed_events = set(
+ event_id
+ for event_ids in itervalues(conflicted_state)
+ for event_id in event_ids
+ )
+ if event_map is not None:
+ needed_events -= set(iterkeys(event_map))
+
+ logger.info("Asking for %d conflicted events", len(needed_events))
+
+ # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+ # the state events which are in conflict (and those in event_map)
+ state_map = yield state_map_factory(needed_events)
+ if event_map is not None:
+ state_map.update(event_map)
+
+ # get the ids of the auth events which allow us to authenticate the
+ # conflicted state, picking only from the unconflicting state.
+ #
+ # dict[(str, str), str]: a map from state key to event id
+ auth_events = _create_auth_events_from_maps(
+ unconflicted_state, conflicted_state, state_map
+ )
+
+ new_needed_events = set(itervalues(auth_events))
+ new_needed_events -= needed_events
+ if event_map is not None:
+ new_needed_events -= set(iterkeys(event_map))
+
+ logger.info("Asking for %d auth events", len(new_needed_events))
+
+ state_map_new = yield state_map_factory(new_needed_events)
+ state_map.update(state_map_new)
+
+ defer.returnValue(_resolve_with_state(
+ unconflicted_state, conflicted_state, auth_events, state_map
+ ))
+
+
+def _seperate(state_sets):
+ """Takes the state_sets and figures out which keys are conflicted and
+ which aren't. i.e., which have multiple different event_ids associated
+ with them in different state sets.
+
+ Args:
+ state_sets(iterable[dict[(str, str), str]]):
+ List of dicts of (type, state_key) -> event_id, which are the
+ different state groups to resolve.
+
+ Returns:
+ (dict[(str, str), str], dict[(str, str), set[str]]):
+ A tuple of (unconflicted_state, conflicted_state), where:
+
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
+
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
+ """
+ state_set_iterator = iter(state_sets)
+ unconflicted_state = dict(next(state_set_iterator))
+ conflicted_state = {}
+
+ for state_set in state_set_iterator:
+ for key, value in iteritems(state_set):
+ # Check if there is an unconflicted entry for the state key.
+ unconflicted_value = unconflicted_state.get(key)
+ if unconflicted_value is None:
+ # There isn't an unconflicted entry so check if there is a
+ # conflicted entry.
+ ls = conflicted_state.get(key)
+ if ls is None:
+ # There wasn't a conflicted entry so haven't seen this key before.
+ # Therefore it isn't conflicted yet.
+ unconflicted_state[key] = value
+ else:
+ # This key is already conflicted, add our value to the conflict set.
+ ls.add(value)
+ elif unconflicted_value != value:
+ # If the unconflicted value is not the same as our value then we
+ # have a new conflict. So move the key from the unconflicted_state
+ # to the conflicted state.
+ conflicted_state[key] = {value, unconflicted_value}
+ unconflicted_state.pop(key, None)
+
+ return unconflicted_state, conflicted_state
+
+
+def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+ auth_events = {}
+ for event_ids in itervalues(conflicted_state):
+ for event_id in event_ids:
+ if event_id in state_map:
+ keys = event_auth.auth_types_for_event(state_map[event_id])
+ for key in keys:
+ if key not in auth_events:
+ event_id = unconflicted_state.get(key, None)
+ if event_id:
+ auth_events[key] = event_id
+ return auth_events
+
+
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
+ state_map):
+ conflicted_state = {}
+ for key, event_ids in iteritems(conflicted_state_ids):
+ events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
+ if len(events) > 1:
+ conflicted_state[key] = events
+ elif len(events) == 1:
+ unconflicted_state_ids[key] = events[0].event_id
+
+ auth_events = {
+ key: state_map[ev_id]
+ for key, ev_id in iteritems(auth_event_ids)
+ if ev_id in state_map
+ }
+
+ try:
+ resolved_state = _resolve_state_events(
+ conflicted_state, auth_events
+ )
+ except Exception:
+ logger.exception("Failed to resolve state")
+ raise
+
+ new_state = unconflicted_state_ids
+ for key, event in iteritems(resolved_state):
+ new_state[key] = event.event_id
+
+ return new_state
+
+
+def _resolve_state_events(conflicted_state, auth_events):
+ """ This is where we actually decide which of the conflicted state to
+ use.
+
+ We resolve conflicts in the following order:
+ 1. power levels
+ 2. join rules
+ 3. memberships
+ 4. other events.
+ """
+ resolved_state = {}
+ if POWER_KEY in conflicted_state:
+ events = conflicted_state[POWER_KEY]
+ logger.debug("Resolving conflicted power levels %r", events)
+ resolved_state[POWER_KEY] = _resolve_auth_events(
+ events, auth_events)
+
+ auth_events.update(resolved_state)
+
+ for key, events in iteritems(conflicted_state):
+ if key[0] == EventTypes.JoinRules:
+ logger.debug("Resolving conflicted join rules %r", events)
+ resolved_state[key] = _resolve_auth_events(
+ events,
+ auth_events
+ )
+
+ auth_events.update(resolved_state)
+
+ for key, events in iteritems(conflicted_state):
+ if key[0] == EventTypes.Member:
+ logger.debug("Resolving conflicted member lists %r", events)
+ resolved_state[key] = _resolve_auth_events(
+ events,
+ auth_events
+ )
+
+ auth_events.update(resolved_state)
+
+ for key, events in iteritems(conflicted_state):
+ if key not in resolved_state:
+ logger.debug("Resolving conflicted state %r:%r", key, events)
+ resolved_state[key] = _resolve_normal_events(
+ events, auth_events
+ )
+
+ return resolved_state
+
+
+def _resolve_auth_events(events, auth_events):
+ reverse = [i for i in reversed(_ordered_events(events))]
+
+ auth_keys = set(
+ key
+ for event in events
+ for key in event_auth.auth_types_for_event(event)
+ )
+
+ new_auth_events = {}
+ for key in auth_keys:
+ auth_event = auth_events.get(key, None)
+ if auth_event:
+ new_auth_events[key] = auth_event
+
+ auth_events = new_auth_events
+
+ prev_event = reverse[0]
+ for event in reverse[1:]:
+ auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+ try:
+ # The signatures have already been checked at this point
+ event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+ prev_event = event
+ except AuthError:
+ return prev_event
+
+ return event
+
+
+def _resolve_normal_events(events, auth_events):
+ for event in _ordered_events(events):
+ try:
+ # The signatures have already been checked at this point
+ event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+ return event
+ except AuthError:
+ pass
+
+ # Use the last event (the one with the least depth) if they all fail
+ # the auth check.
+ return event
+
+
+def _ordered_events(events):
+ def key_func(e):
+ return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
+
+ return sorted(events, key=key_func)
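
A small worked example of `_seperate`'s contract (illustrative event IDs):

```python
from synapse.state.v1 import _seperate

state_sets = [
    {("m.room.name", ""): "$A", ("m.room.topic", ""): "$T"},
    {("m.room.name", ""): "$B", ("m.room.topic", ""): "$T"},
]

unconflicted, conflicted = _seperate(state_sets)

# the topic agrees across both sets; the name does not
assert unconflicted == {("m.room.topic", ""): "$T"}
assert conflicted == {("m.room.name", ""): {"$A", "$B"}}
```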
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 025a7fb6d9..f39c8c8461 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
}
events_map = {ev.event_id: ev for ev, _ in events_context}
+ room_version = yield self.get_room_version(room_id)
+
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
- room_id, state_groups, events_map, get_events
+ room_id, room_version, state_groups, events_map, get_events
)
defer.returnValue((res.state, None))
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 60295da254..88b50f33b5 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
-
-class ProfileStore(ProfileWorkerStore):
def create_profile(self, user_localpart):
return self._simple_insert(
table="profiles",
@@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore):
desc="set_profile_avatar_url",
)
+
+class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 3378fc77d1..61013b8919 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @cachedInlineCallbacks(max_entries=10000)
+ def get_ratelimit_for_user(self, user_id):
+ """Check if there are any overrides for ratelimiting for the given
+ user
+
+ Args:
+ user_id (str)
+
+ Returns:
+ RatelimitOverride if there is an override, else None. If the contents
+ of RatelimitOverride are None or 0 then ratelimiting has been
+ disabled for that user entirely.
+ """
+ row = yield self._simple_select_one(
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=("messages_per_second", "burst_count"),
+ allow_none=True,
+ desc="get_ratelimit_for_user",
+ )
+
+ if row:
+ defer.returnValue(RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ ))
+ else:
+ defer.returnValue(None)
+
class RoomStore(RoomWorkerStore, SearchStore):
@@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
"get_all_new_public_rooms", get_all_new_public_rooms
)
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
- """Check if there are any overrides for ratelimiting for the given
- user
-
- Args:
- user_id (str)
-
- Returns:
- RatelimitOverride if there is an override, else None. If the contents
- of RatelimitOverride are None or 0 then ratelimitng has been
- disabled for that user entirely.
- """
- row = yield self._simple_select_one(
- table="ratelimit_override",
- keyvalues={"user_id": user_id},
- retcols=("messages_per_second", "burst_count"),
- allow_none=True,
- desc="get_ratelimit_for_user",
- )
-
- if row:
- defer.returnValue(RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- ))
- else:
- defer.returnValue(None)
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dd03c4168b..4b971efdba 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def __init__(self, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
+
self._state_group_cache = DictionaryCache(
- "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache")
+ )
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache")
)
@defer.inlineCallbacks
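
A hedged sketch of the lazy-loading query shape this split enables (the method below is illustrative; the real lookups go via `_get_some_state_from_cache` / `_get_all_state_from_cache`, shown further down):

```python
def get_lazy_loaded_state(self, group, wanted_member_keys):
    # ask the members cache for just the membership keys we care about;
    # DictionaryCache handles partial-key lookups like this efficiently
    members, _ = self._get_some_state_from_cache(
        self._state_group_members_cache, group, wanted_member_keys,
    )

    # ask the (smaller) non-members cache for all remaining state
    non_members, _ = self._get_all_state_from_cache(
        self._state_group_cache, group,
    )

    # merging the two halves reconstructs the filtered state dict
    state = dict(non_members)
    state.update(members)
    return state
```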
@@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
})
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, types):
+ def _get_state_groups_from_groups(self, groups, types, members=None):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
@@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
types (Iterable[str, str|None]|None): list of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. If None, all types are returned.
+ members (bool|None): If not None, then, in addition to any filtering
+ implied by types, the results are also filtered to only include
+ member events (if True), or to exclude member events (if False)
Returns:
dictionary state_group -> (dict of (type, state_key) -> event id)
@@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, types,
+ self._get_state_groups_from_groups_txn, chunk, types, members,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
- self, txn, groups, types=None,
+ self, txn, groups, types=None, members=None,
):
results = {group: {} for group in groups}
@@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
%s
""")
+ if members is True:
+ sql += " AND type = '%s'" % (EventTypes.Member,)
+ elif members is False:
+ sql += " AND type <> '%s'" % (EventTypes.Member,)
+
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type separately.
@@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
where_clause = ""
+ if members is True:
+ where_clause += " AND type = '%s'" % EventTypes.Member
+ elif members is False:
+ where_clause += " AND type <> '%s'" % EventTypes.Member
+
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
@@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
- def _get_some_state_from_cache(self, group, types, filtered_types=None):
+ def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
+ cache(DictionaryCache): the state group cache to use
group(int): The state group to lookup
types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all
@@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
- is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
+ is_all, known_absent, state_dict_ids = cache.get(group)
type_to_key = {}
- # tracks whether any of ourrequested types are missing from the cache
+ # tracks whether any of our requested types are missing from the cache
missing_types = False
for typ, state_key in types:
@@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if include(k[0], k[1])
}, got_all
- def _get_all_state_from_cache(self, group):
+ def _get_all_state_from_cache(self, cache, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cache, if False we need to query the DB for the missing state.
Args:
+ cache(DictionaryCache): the state group cache to use
group: The state group to lookup
"""
- is_all, _, state_dict_ids = self._state_group_cache.get(group)
+ is_all, _, state_dict_ids = cache.get(group)
return state_dict_ids, is_all
@@ -685,6 +735,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[dict[int, dict[(type, state_key), EventBase]]]
a dictionary mapping from state group to state dictionary.
"""
+ if types is not None:
+ non_member_types = [t for t in types if t[0] != EventTypes.Member]
+
+ if filtered_types is not None and EventTypes.Member not in filtered_types:
+ # we want all of the membership events
+ member_types = None
+ else:
+ member_types = [t for t in types if t[0] == EventTypes.Member]
+
+ else:
+ non_member_types = None
+ member_types = None
+
+ non_member_state = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, non_member_types, filtered_types,
+ )
+ # XXX: we could skip this entirely if member_types is []
+ member_state = yield self._get_state_for_groups_using_cache(
+ # we set filtered_types=None, as member_state only ever contains members.
+ groups, self._state_group_members_cache, member_types, None,
+ )
+
+ state = non_member_state
+ for group in groups:
+ state[group].update(member_state[group])
+
+ defer.returnValue(state)
+
+ @defer.inlineCallbacks
+ def _get_state_for_groups_using_cache(
+ self, groups, cache, types=None, filtered_types=None
+ ):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ cache (DictionaryCache): the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the specific
+ members state cache.
+ types (None|iterable[(str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type.
+ filtered_types(list[str]|None): Only apply filtering via `types` to this
+ list of event types. Other types of events are returned unfiltered.
+ If None, `types` filtering is applied to all events.
+
+ Returns:
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
+ """
if types:
types = frozenset(types)
results = {}
@@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, got_all = self._get_some_state_from_cache(
- group, types, filtered_types
+ cache, group, types, filtered_types
)
results[group] = state_dict_ids
@@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
for group in set(groups):
state_dict_ids, got_all = self._get_all_state_from_cache(
- group
+ cache, group
)
results[group] = state_dict_ids
@@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
missing_groups.append(group)
if missing_groups:
- # Okay, so we have some missing_types, lets fetch them.
- cache_seq_num = self._state_group_cache.sequence
+ # Okay, so we have some missing_types, let's fetch them.
+ cache_seq_num = cache.sequence
# the DictionaryCache knows if it has *all* the state, but
# does not know if it has all of the keys of a particular type,
@@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types_to_fetch
+ missing_groups, types_to_fetch, cache == self._state_group_members_cache,
)
for group, group_state_dict in iteritems(group_to_state_dict):
@@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# update the cache with all the things we fetched from the
# database.
- self._state_group_cache.update(
+ cache.update(
cache_seq_num,
key=group,
value=group_state_dict,
@@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
],
)
- # Prefill the state group cache with this group.
+ # Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
- value=dict(current_state_ids),
+ value=dict(current_non_member_state_ids),
)
return state_group
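
The comment block at the top of this hunk describes the new split-cache strategy; the sketch below (illustrative only, not part of the patch) shows the lookup-and-merge shape it implies, using only the DictionaryCache operations visible in this diff (`get`, `update`, `sequence`). The `store`, `group` and `wanted_members` names are assumptions for the example.

    # Illustrative sketch only -- not Synapse code. `store`, `group` and
    # `wanted_members` are assumed inputs.
    def lazy_load_state(store, group, wanted_members):
        # The non-members cache is queried for all of its state for the
        # group; it never holds member events, so "all" stays cheap.
        _, _, non_members = store._state_group_cache.get(group)

        # The members cache is consulted only for the member state keys we
        # actually want; DictionaryCache can cache and serve such subsets.
        _, _, members = store._state_group_members_cache.get(group)
        wanted = {
            key: event_id for key, event_id in members.items()
            if key[1] in wanted_members
        }

        # The caches are disjoint by event type, so a plain merge suffices.
        state = dict(non_members)
        state.update(wanted)
        return state
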
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 07e83fadda..a0c2d37610 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -385,7 +385,13 @@ class LoggingContextFilter(logging.Filter):
context = LoggingContext.current_context()
for key, value in self.defaults.items():
setattr(record, key, value)
- context.copy_to(record)
+
+ # context should never be None, but if it somehow ends up being None,
+ # we end up in a death spiral of infinite loops, so let's check, for
+ # robustness' sake.
+ if context is not None:
+ context.copy_to(record)
+
return True
@@ -396,7 +402,9 @@ class PreserveLoggingContext(object):
__slots__ = ["current_context", "new_context", "has_parent"]
- def __init__(self, new_context=LoggingContext.sentinel):
+ def __init__(self, new_context=None):
+ if new_context is None:
+ new_context = LoggingContext.sentinel
self.new_context = new_context
def __enter__(self):
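
For context on the `new_context=None` change above: Python evaluates default argument values once, at definition time, so the old signature captured whatever `LoggingContext.sentinel` happened to be when the class body was executed. A minimal, generic illustration of the pitfall (not Synapse code):

    class Registry(object):
        sentinel = "old"

    def frozen(ctx=Registry.sentinel):   # default captured at def time
        return ctx

    def live(ctx=None):                  # default resolved at call time
        if ctx is None:
            ctx = Registry.sentinel
        return ctx

    Registry.sentinel = "new"
    assert frozen() == "old"   # still sees the stale value
    assert live() == "new"     # picks up the replacement
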
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 62a00189cc..ef31458226 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -20,6 +20,8 @@ import time
from functools import wraps
from inspect import getcallargs
+from six import PY3
+
_TIME_FUNC_ID = 0
@@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
- lineno = f.func_code.co_firstlineno
- pathname = f.func_code.co_filename
+ if PY3:
+ lineno = f.__code__.co_firstlineno
+ pathname = f.__code__.co_filename
+ else:
+ lineno = f.func_code.co_firstlineno
+ pathname = f.func_code.co_filename
record = logging.LogRecord(
name=name,
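
The branch above is the usual Python 2/3 compatibility dance for code objects: Python 3 dropped the `func_code` alias in favour of `__code__`. An equivalent portable lookup, shown only for illustration (this is not what the patch does):

    def f():
        pass

    # __code__ exists on Python 2.6+ and all of Python 3; func_code is
    # Python 2 only, so it serves purely as a fallback here.
    code = getattr(f, "__code__", None) or f.func_code
    print(code.co_filename, code.co_firstlineno)
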
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 43d9db67ec..6f318c6a29 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,6 +16,7 @@
import random
import string
+from six import PY3
from six.moves import range
_string_with_symbols = (
@@ -34,6 +35,17 @@ def random_string_with_symbols(length):
def is_ascii(s):
+
+ if PY3:
+ if isinstance(s, bytes):
+ try:
+ s.decode('ascii').encode('ascii')
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
+ return True
+
try:
s.encode("ascii")
except UnicodeEncodeError:
@@ -49,6 +61,9 @@ def to_ascii(s):
If given None then will return None.
"""
+ if PY3:
+ return s
+
if s is None:
return None
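
A behaviour sketch for the new Python 3 branch of `is_ascii`, assuming the function is imported from `synapse.util.stringutils`: for bytes input, only a failed `decode('ascii')` can reject the value, since re-encoding a string that has just decoded as ASCII cannot fail.

    from synapse.util.stringutils import is_ascii

    assert is_ascii(b"matrix")            # pure-ASCII bytes decode cleanly
    assert not is_ascii(b"caf\xc3\xa9")   # UTF-8 bytes for 'café': decode raises
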
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1fbcd41115..3baba3225a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -30,7 +30,7 @@ def get_version_string(module):
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
@@ -40,7 +40,7 @@ def get_version_string(module):
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
@@ -50,7 +50,7 @@ def get_version_string(module):
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
- ).strip()
+ ).strip().decode('ascii')
except subprocess.CalledProcessError:
git_commit = ""
@@ -60,7 +60,7 @@ def get_version_string(module):
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
- ).strip().endswith(dirty_string)
+ ).strip().decode('ascii').endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
@@ -77,8 +77,8 @@ def get_version_string(module):
"%s (%s)" % (
module.__version__, git_version,
)
- ).encode("ascii")
+ )
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__.encode("ascii")
+ return module.__version__
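
The `.decode('ascii')` calls were added because `subprocess.check_output` returns `bytes` on Python 3, and concatenating bytes with `str` (as in the `"b=" + git_branch` line above) raises `TypeError`. A minimal reproduction, assuming it is run from inside a git checkout:

    import subprocess

    out = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
    assert isinstance(out, bytes)           # always bytes on Python 3
    commit = out.strip().decode('ascii')    # decode before mixing with str
    print("commit: " + commit)
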
diff --git a/tests/app/__init__.py b/tests/app/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/app/__init__.py
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
new file mode 100644
index 0000000000..76b5090fff
--- /dev/null
+++ b/tests/app/test_frontend_proxy.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.app.frontend_proxy import FrontendProxyServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class FrontendProxyTests(HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+
+ hs = self.setup_test_homeserver(
+ http_client=None, homeserverToUse=FrontendProxyServer
+ )
+
+ return hs
+
+ def test_listen_http_with_presence_enabled(self):
+ """
+ When presence is on, the stub servlet will not register.
+ """
+ # Presence is on
+ self.hs.config.use_presence = True
+
+ config = {
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": ["client"]}],
+ }
+
+ # Listen with the config
+ self.hs._listen_http(config)
+
+ # Grab the resource from the site that was told to listen
+ self.assertEqual(len(self.reactor.tcpServers), 1)
+ site = self.reactor.tcpServers[0][1]
+ self.resource = (
+ site.resource.children["_matrix"].children["client"].children["r0"]
+ )
+
+ request, channel = self.make_request("PUT", "presence/a/status")
+ self.render(request)
+
+ # 400 + unrecognised, because nothing is registered
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+ def test_listen_http_with_presence_disabled(self):
+ """
+ When presence is off, the stub servlet will register.
+ """
+ # Presence is off
+ self.hs.config.use_presence = False
+
+ config = {
+ "port": 8080,
+ "bind_addresses": ["0.0.0.0"],
+ "resources": [{"names": ["client"]}],
+ }
+
+ # Listen with the config
+ self.hs._listen_http(config)
+
+ # Grab the resource from the site that was told to listen
+ self.assertEqual(len(self.reactor.tcpServers), 1)
+ site = self.reactor.tcpServers[0][1]
+ self.resource = (
+ site.resource.children["_matrix"].children["client"].children["r0"]
+ )
+
+ request, channel = self.make_request("PUT", "presence/a/status")
+ self.render(request)
+
+ # 401, because the stub servlet still checks authentication
+ self.assertEqual(channel.code, 401)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 62dc69003c..80da1c8954 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError
-from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
@@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
def __init__(self, hs):
- self.profile_handler = ProfileHandler(hs)
+ self.profile_handler = MasterProfileHandler(hs)
class ProfileTestCase(unittest.TestCase):
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 622be2eef8..2ba80ccdcf 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -112,6 +112,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks
def test_invites(self):
+ yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="invite"
@@ -133,7 +134,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks
def test_push_actions_for_user(self):
- yield self.persist(type="m.room.create", creator=USER_ID)
+ yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.join", key=USER_ID, membership="join")
yield self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
new file mode 100644
index 0000000000..66c2b68707
--- /dev/null
+++ b/tests/rest/client/v1/test_presence.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from synapse.rest.client.v1 import presence
+from synapse.types import UserID
+
+from tests import unittest
+
+
+class PresenceTestCase(unittest.HomeserverTestCase):
+ """ Tests presence REST API. """
+
+ user_id = "@sid:red"
+
+ user = UserID.from_string(user_id)
+ servlets = [presence.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+
+ hs = self.setup_test_homeserver(
+ "red", http_client=None, federation_client=Mock()
+ )
+
+ hs.presence_handler = Mock()
+
+ return hs
+
+ def test_put_presence(self):
+ """
+ PUT to the status endpoint with use_presence enabled will call
+ set_state on the presence handler.
+ """
+ self.hs.config.use_presence = True
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ request, channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
+
+ def test_put_presence_disabled(self):
+ """
+ PUT to the status endpoint with use_presence disabled will NOT call
+ set_state on the presence handler.
+ """
+ self.hs.config.use_presence = False
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ request, channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 4be88b8a39..6b7ff813d5 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -25,7 +25,7 @@ from synapse.rest.client.v1_only.register import register_servlets
from synapse.util import Clock
from tests import unittest
-from tests.server import make_request, setup_test_homeserver
+from tests.server import make_request, render, setup_test_homeserver
class CreateUserServletTestCase(unittest.TestCase):
@@ -77,10 +77,7 @@ class CreateUserServletTestCase(unittest.TestCase):
)
request, channel = make_request(b"POST", url, request_data)
- request.render(res)
-
- # Advance the clock because it waits
- self.clock.advance(1)
+ render(request, res, self.clock)
self.assertEquals(channel.result["code"], b"200")
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9f862f9dfa..40dc4ea256 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from tests import unittest
-from tests.server import make_request, wait_until_result
+from tests.server import make_request, render
class RestTestCase(unittest.TestCase):
@@ -171,8 +171,7 @@ class RestHelper(object):
request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8')
)
- request.render(self.resource)
- wait_until_result(self.hs.get_reactor(), channel)
+ render(request, self.resource, self.hs.get_reactor())
assert channel.result["code"] == b"200", channel.result
self.auth_user_id = temp_id
@@ -220,8 +219,7 @@ class RestHelper(object):
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
- request.render(self.resource)
- wait_until_result(self.hs.get_reactor(), channel)
+ render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 8260c130f8..6a886ee3b8 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -24,8 +24,8 @@ from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
+ render,
setup_test_homeserver,
- wait_until_result,
)
PATH_PREFIX = "/_matrix/client/v2_alpha"
@@ -76,8 +76,7 @@ class FilterTestCase(unittest.TestCase):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
@@ -91,8 +90,7 @@ class FilterTestCase(unittest.TestCase):
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -105,8 +103,7 @@ class FilterTestCase(unittest.TestCase):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
@@ -121,8 +118,7 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
@@ -131,8 +127,7 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -143,8 +138,7 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEqual(channel.result["code"], b"400")
@@ -153,7 +147,6 @@ class FilterTestCase(unittest.TestCase):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b72bd0fb7f..1c128e81f5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -11,7 +11,7 @@ from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.util import Clock
from tests import unittest
-from tests.server import make_request, setup_test_homeserver, wait_until_result
+from tests.server import make_request, render, setup_test_homeserver
class RegisterRestServletTestCase(unittest.TestCase):
@@ -72,8 +72,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
@@ -89,16 +88,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
request, channel = make_request(b"POST", self.url, request_data)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
@@ -106,8 +103,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
request, channel = make_request(b"POST", self.url, request_data)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -126,8 +122,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = make_request(b"POST", self.url, request_data)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
det_data = {
"user_id": user_id,
@@ -149,8 +144,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
request, channel = make_request(b"POST", self.url, request_data)
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@@ -162,8 +156,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
det_data = {
"user_id": user_id,
@@ -177,8 +170,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.config.allow_guest_access = False
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ render(request, self.resource, self.clock)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 2e1d06c509..560b1fba96 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -13,72 +13,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.types
-from synapse.http.server import JsonResource
+from mock import Mock
+
from synapse.rest.client.v2_alpha import sync
-from synapse.types import UserID
-from synapse.util import Clock
from tests import unittest
-from tests.server import (
- ThreadedMemoryReactorClock as MemoryReactorClock,
- make_request,
- setup_test_homeserver,
- wait_until_result,
-)
-
-PATH_PREFIX = "/_matrix/client/v2_alpha"
-class FilterTestCase(unittest.TestCase):
+class FilterTestCase(unittest.HomeserverTestCase):
- USER_ID = "@apple:test"
- TO_REGISTER = [sync]
+ user_id = "@apple:test"
+ servlets = [sync.register_servlets]
- def setUp(self):
- self.clock = MemoryReactorClock()
- self.hs_clock = Clock(self.clock)
+ def make_homeserver(self, reactor, clock):
- self.hs = setup_test_homeserver(
- self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
+ hs = self.setup_test_homeserver(
+ "red", http_client=None, federation_client=Mock()
)
+ return hs
- self.auth = self.hs.get_auth()
-
- def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.USER_ID),
- "token_id": 1,
- "is_guest": False,
- }
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return synapse.types.create_requester(
- UserID.from_string(self.USER_ID), 1, False, None
- )
-
- self.auth.get_user_by_access_token = get_user_by_access_token
- self.auth.get_user_by_req = get_user_by_req
+ def test_sync_argless(self):
+ request, channel = self.make_request("GET", "/sync")
+ self.render(request)
- self.store = self.hs.get_datastore()
- self.filtering = self.hs.get_filtering()
- self.resource = JsonResource(self.hs)
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(
+ set(
+ [
+ "next_batch",
+ "rooms",
+ "presence",
+ "account_data",
+ "to_device",
+ "device_lists",
+ ]
+ ).issubset(set(channel.json_body.keys()))
+ )
- for r in self.TO_REGISTER:
- r.register_servlets(self.hs, self.resource)
+ def test_sync_presence_disabled(self):
+ """
+ When presence is disabled, the key does not appear in /sync.
+ """
+ self.hs.config.use_presence = False
- def test_sync_argless(self):
- request, channel = make_request("GET", "/_matrix/client/r0/sync")
- request.render(self.resource)
- wait_until_result(self.clock, channel)
+ request, channel = self.make_request("GET", "/sync")
+ self.render(request)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertTrue(
set(
[
"next_batch",
"rooms",
- "presence",
"account_data",
"to_device",
"device_lists",
diff --git a/tests/server.py b/tests/server.py
index beb24cf032..c63b2c3100 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -24,6 +24,7 @@ class FakeChannel(object):
"""
result = attr.ib(default=attr.Factory(dict))
+ _producer = None
@property
def json_body(self):
@@ -49,6 +50,15 @@ class FakeChannel(object):
self.result["body"] += content
+ def registerProducer(self, producer, streaming):
+ self._producer = producer
+
+ def unregisterProducer(self):
+ if self._producer is None:
+ return
+
+ self._producer = None
+
def requestDone(self, _self):
self.result["done"] = True
@@ -111,14 +121,19 @@ def make_request(method, path, content=b""):
return req, channel
-def wait_until_result(clock, channel, timeout=100):
+def wait_until_result(clock, request, timeout=100):
"""
- Wait until the channel has a result.
+ Wait until the request is finished.
"""
clock.run()
x = 0
- while not channel.result:
+ while not request.finished:
+
+ # If there's a producer, tell it to resume producing so we get content
+ if request._channel._producer:
+ request._channel._producer.resumeProducing()
+
x += 1
if x > timeout:
@@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100):
def render(request, resource, clock):
request.render(resource)
- wait_until_result(clock, request._channel)
+ wait_until_result(clock, request)
class ThreadedMemoryReactorClock(MemoryReactorClock):
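
Together these hunks let a test drive a request to completion in a single call: `render()` renders the resource against the request, then pumps the fake reactor, resuming any producer the channel has registered, until `request.finished` is set. The typical call shape, mirroring the test updates elsewhere in this diff (`resource` and `reactor` are whatever the test already holds):

    request, channel = make_request(b"GET", b"/_matrix/client/r0/sync")
    render(request, resource, reactor)   # loops until the request finishes
    assert channel.result["code"] == b"200"
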
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index c4e9fb72bf..02bf975fbf 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RedactionTestCase(unittest.TestCase):
@@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
+
self.depth = 1
@defer.inlineCallbacks
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index c83ef60062..978c66133d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.types import RoomID, UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
class RoomMemberStoreTestCase(unittest.TestCase):
@@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
+ yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
+
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new(
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index ebfd969b36..d717b9f94e 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -185,6 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [], filtered_types=[EventTypes.Member]
)
@@ -197,8 +198,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {},
+ state_dict,
+ )
+
# test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
@@ -207,6 +220,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
@@ -216,6 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=[EventTypes.Member],
@@ -226,6 +252,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group,
+ [(EventTypes.Member, e5.state_key)],
+ filtered_types=[EventTypes.Member],
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
@@ -234,6 +274,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
@@ -254,9 +295,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
{
(e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id,
- (e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
- (e5.type, e5.state_key): e5.event_id,
},
)
@@ -269,8 +307,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
# list fetched keys so it knows it's partial
fetched_keys=(
(e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
),
)
@@ -284,8 +320,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
set(
[
(e1.type, e1.state_key),
- (e3.type, e3.state_key),
- (e5.type, e5.state_key),
]
),
)
@@ -293,8 +327,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id,
- (e3.type, e3.state_key): e3.event_id,
- (e5.type, e5.state_key): e5.event_id,
},
)
@@ -304,14 +336,25 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [], filtered_types=[EventTypes.Member]
)
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
+ room_id = self.room.to_string()
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual({}, state_dict)
+
# test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
)
@@ -319,8 +362,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e3.type, e3.state_key): e3.event_id,
- # e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
@@ -328,6 +382,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
group,
[(EventTypes.Member, e5.state_key)],
filtered_types=[EventTypes.Member],
@@ -337,6 +392,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id,
+ },
+ state_dict,
+ )
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
+ group,
+ [(EventTypes.Member, e5.state_key)],
+ filtered_types=[EventTypes.Member],
+ )
+
+ self.assertEqual(is_all, True)
+ self.assertDictEqual(
+ {
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
@@ -345,8 +414,22 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_cache,
+ group, [(EventTypes.Member, e5.state_key)], filtered_types=None
+ )
+
+ self.assertEqual(is_all, False)
+ self.assertDictEqual({}, state_dict)
+
+ (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+ self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
)
self.assertEqual(is_all, True)
- self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+ self.assertDictEqual(
+ {
+ (e5.type, e5.state_key): e5.event_id,
+ },
+ state_dict,
+ )
diff --git a/tests/test_server.py b/tests/test_server.py
index 895e490406..ef74544e93 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -8,7 +8,7 @@ from synapse.http.server import JsonResource
from synapse.util import Clock
from tests import unittest
-from tests.server import make_request, setup_test_homeserver
+from tests.server import make_request, render, setup_test_homeserver
class JsonResourceTests(unittest.TestCase):
@@ -37,7 +37,7 @@ class JsonResourceTests(unittest.TestCase):
)
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
- request.render(res)
+ render(request, res, self.reactor)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
@@ -55,7 +55,7 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo")
- request.render(res)
+ render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'500')
@@ -78,13 +78,8 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo")
- request.render(res)
+ render(request, res, self.reactor)
- # No error has been raised yet
- self.assertTrue("code" not in channel.result)
-
- # Advance time, now there's an error
- self.reactor.advance(1)
self.assertEqual(channel.result["code"], b'500')
def test_callback_synapseerror(self):
@@ -100,7 +95,7 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo")
- request.render(res)
+ render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'403')
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -121,7 +116,7 @@ class JsonResourceTests(unittest.TestCase):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foobar")
- request.render(res)
+ render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'400')
self.assertEqual(channel.json_body["error"], "Unrecognized request")
diff --git a/tests/test_state.py b/tests/test_state.py
index 96fdb8636c..452a123c3a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.auth import Auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.events import FrozenEvent
from synapse.state import StateHandler, StateResolutionHandler
@@ -117,6 +117,9 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
+ def get_room_version(self, room_id):
+ return RoomVersions.V1
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase):
def test_branch_no_conflict(self):
graph = Graph(
nodes={
- "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
+ "START": DictObj(
+ type=EventTypes.Create, state_key="", content={}, depth=1,
+ ),
"A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3),
"C": DictObj(type=EventTypes.Name, state_key="", depth=3),
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 45a78338d6..8d8ce0cab9 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,7 +21,7 @@ from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server
import tests.unittest
-from tests.utils import setup_test_homeserver
+from tests.utils import create_room, setup_test_homeserver
logger = logging.getLogger(__name__)
@@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_builder_factory = self.hs.get_event_builder_factory()
self.store = self.hs.get_datastore()
+ yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+
@defer.inlineCallbacks
def test_filtering(self):
#
diff --git a/tests/unittest.py b/tests/unittest.py
index e6afe3b96d..d852e2465a 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,8 @@ import logging
from mock import Mock
+from canonicaljson import json
+
import twisted
import twisted.logger
from twisted.trial import unittest
@@ -241,11 +243,15 @@ class HomeserverTestCase(TestCase):
method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such).
- content (bytes): The body of the request.
+ content (bytes or dict): The body of the request. JSON-encoded, if
+ a dict.
Returns:
A synapse.http.site.SynapseRequest.
"""
+ if isinstance(content, dict):
+ content = json.dumps(content).encode('utf8')
+
return make_request(method, path, content)
def render(self, request):
diff --git a/tests/utils.py b/tests/utils.py
index 6f8b1de3e7..9f7ff94575 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,6 +24,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor
+from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.federation.transport import server
from synapse.http.server import HttpServer
@@ -93,7 +94,8 @@ def setupdb():
@defer.inlineCallbacks
def setup_test_homeserver(
- cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs
+ cleanup_func, name="test", datastore=None, config=None, reactor=None,
+ homeserverToUse=HomeServer, **kargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -192,7 +194,7 @@ def setup_test_homeserver(
config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
if datastore is None:
- hs = HomeServer(
+ hs = homeserverToUse(
name,
config=config,
db_config=config.database_config,
@@ -235,7 +237,7 @@ def setup_test_homeserver(
hs.setup()
else:
- hs = HomeServer(
+ hs = homeserverToUse(
name,
db_pool=None,
datastore=datastore,
@@ -538,3 +540,32 @@ class DeferredMockCallable(object):
"Expected not to received any calls, got:\n"
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
)
+
+
+@defer.inlineCallbacks
+def create_room(hs, room_id, creator_id):
+ """Creates and persist a creation event for the given room
+
+ Args:
+ hs
+ room_id (str)
+ creator_id (str)
+ """
+
+ store = hs.get_datastore()
+ event_builder_factory = hs.get_event_builder_factory()
+ event_creation_handler = hs.get_event_creation_handler()
+
+ builder = event_builder_factory.new({
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": creator_id,
+ "room_id": room_id,
+ "content": {},
+ })
+
+ event, context = yield event_creation_handler.create_new_client_event(
+ builder
+ )
+
+ yield store.persist_event(event, context)
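
Usage of the new helper matches the storage-test changes earlier in this diff: call it from a test's `setUp` so that later events injected into the room have a creation event to sit under (sketch; `hs` comes from `setup_test_homeserver`):

    # Inside an @defer.inlineCallbacks setUp:
    yield create_room(hs, "!abc123:test", "@alice:test")
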