diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix
new file mode 100644
index 0000000000..f49c600173
--- /dev/null
+++ b/changelog.d/7384.bugfix
@@ -0,0 +1 @@
+Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
diff --git a/changelog.d/7457.feature b/changelog.d/7457.feature
new file mode 100644
index 0000000000..7ad767bf71
--- /dev/null
+++ b/changelog.d/7457.feature
@@ -0,0 +1 @@
+Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).
diff --git a/changelog.d/7465.bugfix b/changelog.d/7465.bugfix
new file mode 100644
index 0000000000..1cbe50caa5
--- /dev/null
+++ b/changelog.d/7465.bugfix
@@ -0,0 +1 @@
+Prevent rooms with 0 members or with invalid version strings from breaking group queries.
\ No newline at end of file
diff --git a/changelog.d/7491.misc b/changelog.d/7491.misc
new file mode 100644
index 0000000000..50eb226db7
--- /dev/null
+++ b/changelog.d/7491.misc
@@ -0,0 +1 @@
+Move event stream handling out of slave store.
diff --git a/changelog.d/7505.misc b/changelog.d/7505.misc
new file mode 100644
index 0000000000..26114a3744
--- /dev/null
+++ b/changelog.d/7505.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.event_auth`.
diff --git a/changelog.d/7511.bugfix b/changelog.d/7511.bugfix
new file mode 100644
index 0000000000..cf8bc69c6f
--- /dev/null
+++ b/changelog.d/7511.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug that broke the background process for updating remote profiles.
diff --git a/changelog.d/7513.misc b/changelog.d/7513.misc
new file mode 100644
index 0000000000..2ea7373e29
--- /dev/null
+++ b/changelog.d/7513.misc
@@ -0,0 +1 @@
+Add type hints to room member handler.
diff --git a/changelog.d/7514.doc b/changelog.d/7514.doc
new file mode 100644
index 0000000000..981168c7e8
--- /dev/null
+++ b/changelog.d/7514.doc
@@ -0,0 +1 @@
+Improve the formatting of `reverse_proxy.md`.
diff --git a/changelog.d/7516.misc b/changelog.d/7516.misc
new file mode 100644
index 0000000000..94b0fd49b2
--- /dev/null
+++ b/changelog.d/7516.misc
@@ -0,0 +1 @@
+Add a worker store for search insertion, required for moving event persistence off master.
diff --git a/changelog.d/7518.misc b/changelog.d/7518.misc
new file mode 100644
index 0000000000..f6e143fe1c
--- /dev/null
+++ b/changelog.d/7518.misc
@@ -0,0 +1 @@
+Fix typing annotations in `tests.replication`.
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 82bd5d1cdf..cbb8269568 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -34,97 +34,107 @@ the reverse proxy and the homeserver.
### nginx
- server {
- listen 443 ssl;
- listen [::]:443 ssl;
- server_name matrix.example.com;
-
- location /_matrix {
- proxy_pass http://localhost:8008;
- proxy_set_header X-Forwarded-For $remote_addr;
- # Nginx by default only allows file uploads up to 1M in size
- # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
- client_max_body_size 10M;
- }
- }
-
- server {
- listen 8448 ssl default_server;
- listen [::]:8448 ssl default_server;
- server_name example.com;
-
- location / {
- proxy_pass http://localhost:8008;
- proxy_set_header X-Forwarded-For $remote_addr;
- }
- }
-
-> **NOTE**: Do not add a `/` after the port in `proxy_pass`, otherwise nginx will
+```
+server {
+ listen 443 ssl;
+ listen [::]:443 ssl;
+ server_name matrix.example.com;
+
+ location /_matrix {
+ proxy_pass http://localhost:8008;
+ proxy_set_header X-Forwarded-For $remote_addr;
+ # Nginx by default only allows file uploads up to 1M in size
+ # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
+ client_max_body_size 10M;
+ }
+}
+
+server {
+ listen 8448 ssl default_server;
+ listen [::]:8448 ssl default_server;
+ server_name example.com;
+
+ location / {
+ proxy_pass http://localhost:8008;
+ proxy_set_header X-Forwarded-For $remote_addr;
+ }
+}
+```
+
+**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
canonicalise/normalise the URI.
### Caddy 1
- matrix.example.com {
- proxy /_matrix http://localhost:8008 {
- transparent
- }
- }
+```
+matrix.example.com {
+ proxy /_matrix http://localhost:8008 {
+ transparent
+ }
+}
- example.com:8448 {
- proxy / http://localhost:8008 {
- transparent
- }
- }
+example.com:8448 {
+ proxy / http://localhost:8008 {
+ transparent
+ }
+}
+```
### Caddy 2
- matrix.example.com {
- reverse_proxy /_matrix/* http://localhost:8008
- }
+```
+matrix.example.com {
+ reverse_proxy /_matrix/* http://localhost:8008
+}
- example.com:8448 {
- reverse_proxy http://localhost:8008
- }
+example.com:8448 {
+ reverse_proxy http://localhost:8008
+}
+```
### Apache
- <VirtualHost *:443>
- SSLEngine on
- ServerName matrix.example.com;
+```
+<VirtualHost *:443>
+ SSLEngine on
+ ServerName matrix.example.com
- AllowEncodedSlashes NoDecode
- ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
- ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
- </VirtualHost>
+ AllowEncodedSlashes NoDecode
+ ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
+ ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+</VirtualHost>
- <VirtualHost *:8448>
- SSLEngine on
- ServerName example.com;
+<VirtualHost *:8448>
+ SSLEngine on
+ ServerName example.com
- AllowEncodedSlashes NoDecode
- ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
- ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
- </VirtualHost>
+ AllowEncodedSlashes NoDecode
+ ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
+ ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
+</VirtualHost>
+```
-> **NOTE**: ensure the `nocanon` options are included.
+**NOTE**: ensure the `nocanon` options are included.
### HAProxy
- frontend https
- bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
+```
+frontend https
+ bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
- # Matrix client traffic
- acl matrix-host hdr(host) -i matrix.example.com
- acl matrix-path path_beg /_matrix
+ # Matrix client traffic
+ acl matrix-host hdr(host) -i matrix.example.com
+ acl matrix-path path_beg /_matrix
- use_backend matrix if matrix-host matrix-path
+ use_backend matrix if matrix-host matrix-path
- frontend matrix-federation
- bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
- default_backend matrix
+frontend matrix-federation
+ bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
+ default_backend matrix
- backend matrix
- server matrix 127.0.0.1:8008
+backend matrix
+ server matrix 127.0.0.1:8008
+```
## Homeserver Configuration
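A quick way to see why URI canonicalisation by the proxy is harmful: Matrix request paths and query strings can carry percent-encoded characters (room aliases, event IDs) that must reach the homeserver untouched. A minimal Python sketch, using a hypothetical room alias, of what a normalising proxy effectively does to such a request:

```python
from urllib.parse import quote, unquote

# Hypothetical federation-style URI containing percent-encoded characters,
# including an encoded slash (%2F) inside the room alias.
original = "/_matrix/federation/v1/query/directory?room_alias=%23room%2Fname%3Aexample.com"

# A proxy that canonicalises the URI effectively does this:
decoded = unquote(original)
print(decoded)
# /_matrix/federation/v1/query/directory?room_alias=#room/name:example.com

# Re-encoding cannot recover the original: the %2F is now an ordinary
# slash, so the homeserver sees a structurally different request.
print(quote(decoded, safe="/?=&") == original)  # False
```

This is the failure mode that the nginx note above and Apache's `nocanon` option guard against.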
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 2e3add7ac5..ab801108ca 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -122,6 +122,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@@ -451,6 +452,7 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
+ SearchWorkerStore,
BaseSlavedStore,
):
def __init__(self, database, db_conn, hs):
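For readers unfamiliar with the store composition above: `GenericWorkerSlavedStore` is assembled from many small store classes, and Python's method resolution order decides where each method is looked up. A minimal sketch of the same mixin pattern, with made-up store names unrelated to Synapse's real classes:

```python
# Made-up classes illustrating mixin-based store composition.
class BaseStore:
    def __init__(self, db):
        self.db = db

class SearchWorkerStoreSketch(BaseStore):
    def store_search_entries(self, entries):
        # hypothetical method; the real store writes to the search index
        return ["indexed:%s" % e for e in entries]

class MediaStoreSketch(BaseStore):
    def get_media(self, media_id):
        return "media:%s" % (media_id,)

class WorkerStore(SearchWorkerStoreSketch, MediaStoreSketch, BaseStore):
    pass

store = WorkerStore(db="conn")
print([cls.__name__ for cls in WorkerStore.__mro__])
# ['WorkerStore', 'SearchWorkerStoreSketch', 'MediaStoreSketch', 'BaseStore', 'object']
print(store.store_search_entries(["event1"]))
```

Listing `SearchWorkerStore` before `BaseSlavedStore` means its methods are found ahead of the base class's, which is why the position in the class list matters.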
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index aea3985a5f..1b13e84425 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -270,7 +270,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
- def get_exlusive_user_regexes(self):
+ def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 5a5b568a95..c582355146 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Set, Tuple
+from typing import List, Optional, Set, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -29,18 +29,19 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersion,
)
-from synapse.types import UserID, get_domain_from_id
+from synapse.events import EventBase
+from synapse.types import StateMap, UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(
room_version_obj: RoomVersion,
- event,
- auth_events,
- do_sig_check=True,
- do_size_check=True,
-):
+ event: EventBase,
+ auth_events: StateMap[EventBase],
+ do_sig_check: bool = True,
+ do_size_check: bool = True,
+) -> None:
""" Checks if this event is correctly authed.
Args:
@@ -189,7 +190,7 @@ def check(
logger.debug("Allowing! %s", event)
-def _check_size_limits(event):
+def _check_size_limits(event: EventBase) -> None:
def too_big(field):
raise EventSizeError("%s too large" % (field,))
@@ -207,13 +208,18 @@ def _check_size_limits(event):
too_big("event")
-def _can_federate(event, auth_events):
+def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
creation_event = auth_events.get((EventTypes.Create, ""))
+ # There should always be a creation event, but if not, don't federate.
+ if not creation_event:
+ return False
return creation_event.content.get("m.federate", True) is True
-def _is_membership_change_allowed(event, auth_events):
+def _is_membership_change_allowed(
+ event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
membership = event.content["membership"]
# Check if this is the room creator joining:
@@ -339,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events):
raise AuthError(500, "Unknown membership %s" % membership)
-def _check_event_sender_in_room(event, auth_events):
+def _check_event_sender_in_room(
+ event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)
- return _check_joined_room(member_event, event.user_id, event.room_id)
+ _check_joined_room(member_event, event.user_id, event.room_id)
-def _check_joined_room(member, user_id, room_id):
+def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
if not member or member.membership != Membership.JOIN:
raise AuthError(
403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
)
-def get_send_level(etype, state_key, power_levels_event):
+def get_send_level(
+ etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
+) -> int:
"""Get the power level required to send an event of a given type
The federation spec [1] refers to this as "Required Power Level".
@@ -361,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event):
https://matrix.org/docs/spec/server_server/unstable.html#definitions
Args:
- etype (str): type of event
- state_key (str|None): state_key of state event, or None if it is not
+ etype: type of event
+ state_key: state_key of state event, or None if it is not
a state event.
- power_levels_event (synapse.events.EventBase|None): power levels event
+ power_levels_event: power levels event
in force at this point in the room
Returns:
- int: power level required to send this event.
+ power level required to send this event.
"""
if power_levels_event:
@@ -388,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event):
return int(send_level)
-def _can_send_event(event, auth_events):
+def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
power_levels_event = _get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
@@ -410,7 +420,9 @@ def _can_send_event(event, auth_events):
return True
-def check_redaction(room_version_obj: RoomVersion, event, auth_events):
+def check_redaction(
+ room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> bool:
"""Check whether the event sender is allowed to redact the target event.
Returns:
@@ -442,7 +454,9 @@ def check_redaction(room_version_obj: RoomVersion, event, auth_events):
raise AuthError(403, "You don't have permission to redact events")
-def _check_power_levels(room_version_obj, event, auth_events):
+def _check_power_levels(
+ room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> None:
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
@@ -473,7 +487,7 @@ def _check_power_levels(room_version_obj, event, auth_events):
("redact", None),
("kick", None),
("invite", None),
- ]
+ ] # type: List[Tuple[str, Optional[str]]]
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
@@ -503,12 +517,12 @@ def _check_power_levels(room_version_obj, event, auth_events):
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
- old_level = int(old_loc[level_to_check])
+ old_level = int(old_loc[level_to_check]) # type: Optional[int]
else:
old_level = None
if level_to_check in new_loc:
- new_level = int(new_loc[level_to_check])
+ new_level = int(new_loc[level_to_check]) # type: Optional[int]
else:
new_level = None
@@ -534,21 +548,21 @@ def _check_power_levels(room_version_obj, event, auth_events):
)
-def _get_power_level_event(auth_events):
+def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
return auth_events.get((EventTypes.PowerLevels, ""))
-def get_user_power_level(user_id, auth_events):
+def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
"""Get a user's power level
Args:
- user_id (str): user's id to look up in power_levels
- auth_events (dict[(str, str), synapse.events.EventBase]):
+ user_id: user's id to look up in power_levels
+ auth_events:
state in force at this point in the room (or rather, a subset of
it including at least the create event and power levels event).
Returns:
- int: the user's power level in this room.
+ the user's power level in this room.
"""
power_level_event = _get_power_level_event(auth_events)
if power_level_event:
@@ -574,7 +588,7 @@ def get_user_power_level(user_id, auth_events):
return 0
-def _get_named_level(auth_events, name, default):
+def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
power_level_event = _get_power_level_event(auth_events)
if not power_level_event:
@@ -587,7 +601,7 @@ def _get_named_level(auth_events, name, default):
return default
-def _verify_third_party_invite(event, auth_events):
+def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
"""
Validates that the invite event is authorized by a previous third-party invite.
@@ -662,7 +676,7 @@ def get_public_keys(invite_event):
return public_keys
-def auth_types_for_event(event) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
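For context on the newly annotated `get_send_level`: the required level comes from the power-levels event content, with a per-type override in `events` taking precedence over `state_default`/`events_default`. A simplified, standalone restatement of that lookup (the real function also handles the absence of a power-levels event and string-typed levels):

```python
from typing import Optional

def send_level_sketch(etype: str, state_key: Optional[str], pl_content: dict) -> int:
    # a per-event-type override takes precedence
    send_level = pl_content.get("events", {}).get(etype)
    if send_level is None:
        # state events fall back to state_default, other events to events_default
        if state_key is not None:
            send_level = pl_content.get("state_default", 50)
        else:
            send_level = pl_content.get("events_default", 0)
    return int(send_level)

pl = {"events": {"m.room.name": 75}, "state_default": 50, "events_default": 0}
print(send_level_sketch("m.room.name", "", pl))       # 75 (explicit override)
print(send_level_sketch("m.room.topic", "", pl))      # 50 (state_default)
print(send_level_sketch("m.room.message", None, pl))  # 0  (events_default)
```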
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 524281d2f1..75b39e878c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
+ self._sso_enabled = (
+ hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
+ )
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 178f263439..4ba8c7fda5 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -311,7 +311,7 @@ class OidcHandler:
``ClientAuth`` to authenticate with the client with its ID and secret.
Args:
- code: The autorization code we got from the callback.
+ code: The authorization code we got from the callback.
Returns:
A dict containing various tokens.
@@ -497,11 +497,14 @@ class OidcHandler:
return UserInfo(claims)
async def handle_redirect_request(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> None:
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
"""Handle an incoming request to /login/sso/redirect
- It redirects the browser to the authorization endpoint with a few
+ It returns a redirect to the authorization endpoint with a few
parameters:
- ``client_id``: the client ID set in ``oidc_config.client_id``
@@ -511,24 +514,32 @@ class OidcHandler:
- ``state``: a random string
- ``nonce``: a random string
- In addition to redirecting the client, we are setting a cookie with
+ In addition to generating a redirect URL, we are setting a cookie with
a signed macaroon token containing the state, the nonce and the
client_redirect_url params. Those are then checked when the client
comes back from the provider.
-
Args:
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
when everything is done
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
+
+ Returns:
+ The redirect URL to the authorization endpoint.
+
"""
state = generate_token()
nonce = generate_token()
cookie = self._generate_oidc_session_token(
- state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url.decode(),
+ ui_auth_session_id=ui_auth_session_id,
)
request.addCookie(
SESSION_COOKIE_NAME,
@@ -541,7 +552,7 @@ class OidcHandler:
metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
- uri = prepare_grant_uri(
+ return prepare_grant_uri(
authorization_endpoint,
client_id=self._client_auth.client_id,
response_type="code",
@@ -550,8 +561,6 @@ class OidcHandler:
state=state,
nonce=nonce,
)
- request.redirect(uri)
- finish_request(request)
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
@@ -625,7 +634,11 @@ class OidcHandler:
# Deserialize the session token and verify it.
try:
- nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+ (
+ nonce,
+ client_redirect_url,
+ ui_auth_session_id,
+ ) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
@@ -678,15 +691,21 @@ class OidcHandler:
return
# and finally complete the login
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url
- )
+ if ui_auth_session_id:
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
+ else:
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url
+ )
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
+ ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
@@ -702,6 +721,8 @@ class OidcHandler:
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
@@ -718,12 +739,19 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
+ if ui_auth_session_id:
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (ui_auth_session_id,)
+ )
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
return macaroon.serialize()
- def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+ def _verify_oidc_session_token(
+ self, session: str, state: str
+ ) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
@@ -734,7 +762,7 @@ class OidcHandler:
state: The state the OIDC provider gave back
Returns:
- The nonce and the client_redirect_url for this session
+ The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -744,17 +772,27 @@ class OidcHandler:
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+ # Sometimes there's a UI auth session ID; it seems to be OK to
+ # always attempt to satisfy this caveat.
+ v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
- # Extract the `nonce` and `client_redirect_url` from the token
+ # Extract the `nonce`, `client_redirect_url`, and maybe the
+ # `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
+ try:
+ ui_auth_session_id = self._get_value_from_macaroon(
+ macaroon, "ui_auth_session_id"
+ ) # type: Optional[str]
+ except ValueError:
+ ui_auth_session_id = None
- return nonce, client_redirect_url
+ return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
@@ -773,7 +811,7 @@ class OidcHandler:
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
- raise Exception("No %s caveat in macaroon" % (key,))
+ raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4ddeba4c97..e51e1c32fe 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,13 +17,16 @@
import abc
import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Union
from six.moves import http_client
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import Collection, RoomID, UserID
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -74,84 +77,84 @@ class RoomMemberHandler(object):
self.base_handler = BaseHandler(hs)
@abc.abstractmethod
- async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Optional[dict]:
"""Try and join a room that this server is not in
Args:
- requester (Requester)
- remote_room_hosts (list[str]): List of servers that can be used
- to join via.
- room_id (str): Room that we are trying to join
- user (UserID): User who is trying to join
- content (dict): A dict that should be used as the content of the
- join event.
-
- Returns:
- Deferred
+ requester
+ remote_room_hosts: List of servers that can be used to join via.
+ room_id: Room that we are trying to join
+ user: User who is trying to join
+ content: A dict that should be used as the content of the join event.
"""
raise NotImplementedError()
@abc.abstractmethod
async def _remote_reject_invite(
- self, requester, remote_room_hosts, room_id, target, content
- ):
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ target: UserID,
+ content: dict,
+ ) -> dict:
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
- requester (Requester)
- remote_room_hosts (list[str]): List of servers to use to try and
- reject invite
- room_id (str)
- target (UserID): The user rejecting the invite
- content (dict): The content for the rejection event
+ requester
+ remote_room_hosts: List of servers to use to try and reject invite
+ room_id
+ target: The user rejecting the invite
+ content: The content for the rejection event
Returns:
- Deferred[dict]: A dictionary to be returned to the client, may
+ A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
raise NotImplementedError()
@abc.abstractmethod
- async def _user_joined_room(self, target, room_id):
+ async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
room.
Args:
- target (UserID)
- room_id (str)
-
- Returns:
- None
+ target
+ room_id
"""
raise NotImplementedError()
@abc.abstractmethod
- async def _user_left_room(self, target, room_id):
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
Args:
- target (UserID)
- room_id (str)
-
- Returns:
- None
+ target
+ room_id
"""
raise NotImplementedError()
async def _local_membership_update(
self,
- requester,
- target,
- room_id,
- membership,
+ requester: Requester,
+ target: UserID,
+ room_id: str,
+ membership: str,
prev_event_ids: Collection[str],
- txn_id=None,
- ratelimit=True,
- content=None,
- require_consent=True,
- ):
+ txn_id: Optional[str] = None,
+ ratelimit: bool = True,
+ content: Optional[dict] = None,
+ require_consent: bool = True,
+ ) -> EventBase:
user_id = target.to_string()
if content is None:
@@ -214,16 +217,13 @@ class RoomMemberHandler(object):
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
- ):
+ ) -> None:
"""Copies the tags and direct room state from one room to another.
Args:
- old_room_id (str)
- new_room_id (str)
- user_id (str)
-
- Returns:
- Deferred[None]
+ old_room_id: The room ID of the old room.
+ new_room_id: The room ID of the new room.
+ user_id: The user's ID.
"""
# Retrieve user account data for predecessor room
user_account_data, _ = await self.store.get_account_data_for_user(user_id)
@@ -253,17 +253,17 @@ class RoomMemberHandler(object):
async def update_membership(
self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
- ):
+ requester: Requester,
+ target: UserID,
+ room_id: str,
+ action: str,
+ txn_id: Optional[str] = None,
+ remote_room_hosts: Optional[List[str]] = None,
+ third_party_signed: Optional[dict] = None,
+ ratelimit: bool = True,
+ content: Optional[dict] = None,
+ require_consent: bool = True,
+ ) -> Union[EventBase, Optional[dict]]:
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -284,17 +284,17 @@ class RoomMemberHandler(object):
async def _update_membership(
self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
- ):
+ requester: Requester,
+ target: UserID,
+ room_id: str,
+ action: str,
+ txn_id: Optional[str] = None,
+ remote_room_hosts: Optional[List[str]] = None,
+ third_party_signed: Optional[dict] = None,
+ ratelimit: bool = True,
+ content: Optional[dict] = None,
+ require_consent: bool = True,
+ ) -> Union[EventBase, Optional[dict]]:
content_specified = bool(content)
if content is None:
content = {}
@@ -468,12 +468,11 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
- res = await self._remote_reject_invite(
+ return await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
- return res
- res = await self._local_membership_update(
+ return await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@@ -484,9 +483,10 @@ class RoomMemberHandler(object):
content=content,
require_consent=require_consent,
)
- return res
- async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ async def transfer_room_state_on_room_upgrade(
+ self, old_room_id: str, room_id: str
+ ) -> None:
"""Upon our server becoming aware of an upgraded room, either by upgrading a room
ourselves or joining one, we can transfer over information from the previous room.
@@ -494,12 +494,8 @@ class RoomMemberHandler(object):
well as migrating the room directory state.
Args:
- old_room_id (str): The ID of the old room
-
- room_id (str): The ID of the new room
-
- Returns:
- Deferred
+ old_room_id: The ID of the old room
+ room_id: The ID of the new room
"""
logger.info("Transferring room state from %s to %s", old_room_id, room_id)
@@ -526,17 +522,16 @@ class RoomMemberHandler(object):
# Remove the old room from those groups
await self.store.remove_room_from_group(group_id, old_room_id)
- async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ async def copy_user_state_on_room_upgrade(
+ self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
+ ) -> None:
"""Copy user-specific information when they join a new room when that new room is the
result of a room upgrade
Args:
- old_room_id (str): The ID of upgraded room
- new_room_id (str): The ID of the new room
- user_ids (Iterable[str]): User IDs to copy state for
-
- Returns:
- Deferred
+ old_room_id: The ID of the upgraded room
+ new_room_id: The ID of the new room
+ user_ids: User IDs to copy state for
"""
logger.debug(
@@ -566,17 +561,23 @@ class RoomMemberHandler(object):
)
continue
- async def send_membership_event(self, requester, event, context, ratelimit=True):
+ async def send_membership_event(
+ self,
+ requester: Requester,
+ event: EventBase,
+ context: EventContext,
+ ratelimit: bool = True,
+ ):
"""
Change the membership status of a user in a room.
Args:
- requester (Requester): The local user who requested the membership
+ requester: The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
- event (SynapseEvent): The membership event.
+ event: The membership event.
context: The context of the event.
- ratelimit (bool): Whether to rate limit this request.
+ ratelimit: Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
@@ -636,7 +637,9 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- async def _can_guest_join(self, current_state_ids):
+ async def _can_guest_join(
+ self, current_state_ids: Dict[Tuple[str, str], str]
+ ) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@@ -653,12 +656,14 @@ class RoomMemberHandler(object):
and guest_access.content["guest_access"] == "can_join"
)
- async def lookup_room_alias(self, room_alias):
+ async def lookup_room_alias(
+ self, room_alias: RoomAlias
+ ) -> Tuple[RoomID, List[str]]:
"""
Get the room ID associated with a room alias.
Args:
- room_alias (RoomAlias): The alias to look up.
+ room_alias: The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
@@ -682,24 +687,25 @@ class RoomMemberHandler(object):
return RoomID.from_string(room_id), servers
- async def _get_inviter(self, user_id, room_id):
+ async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
invite = await self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:
return UserID.from_string(invite.sender)
+ return None
async def do_3pid_invite(
self,
- room_id,
- inviter,
- medium,
- address,
- id_server,
- requester,
- txn_id,
- id_access_token=None,
- ):
+ room_id: str,
+ inviter: UserID,
+ medium: str,
+ address: str,
+ id_server: str,
+ requester: Requester,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> None:
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@@ -748,15 +754,15 @@ class RoomMemberHandler(object):
async def _make_and_store_3pid_invite(
self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- user,
- txn_id,
- id_access_token=None,
- ):
+ requester: Requester,
+ id_server: str,
+ medium: str,
+ address: str,
+ room_id: str,
+ user: UserID,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> None:
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
@@ -830,7 +836,9 @@ class RoomMemberHandler(object):
txn_id=txn_id,
)
- async def _is_host_in_room(self, current_state_ids):
+ async def _is_host_in_room(
+ self, current_state_ids: Dict[Tuple[str, str], str]
+ ) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -852,7 +860,7 @@ class RoomMemberHandler(object):
return False
- async def _is_server_notice_room(self, room_id):
+ async def _is_server_notice_room(self, room_id: str) -> bool:
if self._server_notices_mxid is None:
return False
user_ids = await self.store.get_users_in_room(room_id)
@@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
- async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+ async def _is_remote_room_too_complex(
+ self, room_id: str, remote_room_hosts: List[str]
+ ) -> Optional[bool]:
"""
Check if complexity of a remote room is too great.
Args:
- room_id (str)
- remote_room_hosts (list[str])
+ room_id
+ remote_room_hosts
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
@@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
return None
- async def _is_local_room_too_complex(self, room_id):
+ async def _is_local_room_too_complex(self, room_id: str) -> bool:
"""
Check if the complexity of a local room is too great.
Args:
- room_id (str)
-
- Returns: bool
+ room_id: The room ID to check for complexity.
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = await self.store.get_room_complexity(room_id)
return complexity["v1"] > max_complexity
- async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> None:
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
async def _remote_reject_invite(
- self, requester, remote_room_hosts, room_id, target, content
- ):
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ target: UserID,
+ content: dict,
+ ) -> dict:
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
@@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler):
await self.store.locally_reject_invite(target.to_string(), room_id)
return {}
- async def _user_joined_room(self, target, room_id):
+ async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
"""
- return user_joined_room(self.distributor, target, room_id)
+ user_joined_room(self.distributor, target, room_id)
- async def _user_left_room(self, target, room_id):
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
- return user_left_room(self.distributor, target, room_id)
+ user_left_room(self.distributor, target, room_id)
- async def forget(self, user, room_id):
+ async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
member = await self.state_handler.get_current_state(
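The membership methods annotated above all run under `self.member_linearizer.queue((room_id,))`, which serialises concurrent membership changes within a room while letting different rooms proceed in parallel. A minimal asyncio sketch of that queueing behaviour (Synapse's `Linearizer` is Twisted-based and more featureful):

```python
import asyncio

class LinearizerSketch:
    def __init__(self):
        self._locks = {}

    def queue(self, key):
        # one lock per key: same-key operations run one at a time,
        # different keys do not block each other
        return self._locks.setdefault(key, asyncio.Lock())

linearizer = LinearizerSketch()

async def update_membership(room_id, action):
    async with linearizer.queue((room_id,)):
        print("start", room_id, action)
        await asyncio.sleep(0.01)  # stand-in for the real membership update
        print("done ", room_id, action)

async def main():
    await asyncio.gather(
        update_membership("!a:example.com", "join"),
        update_membership("!a:example.com", "leave"),  # queued behind the join
        update_membership("!b:example.com", "join"),   # runs concurrently
    )

asyncio.run(main())
```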
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 0fc54349ab..5c776cc0be 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import List, Optional
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@@ -22,6 +23,7 @@ from synapse.replication.http.membership import (
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
+from synapse.types import Requester, UserID
logger = logging.getLogger(__name__)
@@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
- async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ user: UserID,
+ content: dict,
+ ) -> Optional[dict]:
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
@@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret
async def _remote_reject_invite(
- self, requester, remote_room_hosts, room_id, target, content
- ):
+ self,
+ requester: Requester,
+ remote_room_hosts: List[str],
+ room_id: str,
+ target: UserID,
+ content: dict,
+ ) -> dict:
"""Implements RoomMemberHandler._remote_reject_invite
"""
return await self._remote_reject_client(
@@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
content=content,
)
- async def _user_joined_room(self, target, room_id):
+ async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
"""
- return await self._notify_change_client(
+ await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="joined"
)
- async def _user_left_room(self, target, room_id):
+ async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
- return await self._notify_change_client(
+ await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b313720a4b..1a1a50a24f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,11 +15,6 @@
# limitations under the License.
import logging
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
- EventsStreamCurrentStateRow,
- EventsStreamEventRow,
-)
from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
from synapse.storage.data_stores.main.event_push_actions import (
EventPushActionsWorkerStore,
@@ -35,7 +30,6 @@ from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__)
@@ -62,11 +56,6 @@ class SlavedEventStore(
BaseSlavedStore,
):
def __init__(self, database: Database, db_conn, hs):
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
-
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
@@ -92,81 +81,3 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
-
- def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "events":
- self._stream_id_gen.advance(token)
- for row in rows:
- self._process_event_stream_row(token, row)
- elif stream_name == "backfill":
- self._backfill_id_gen.advance(-token)
- for row in rows:
- self.invalidate_caches_for_event(
- -token,
- row.event_id,
- row.room_id,
- row.type,
- row.state_key,
- row.redacts,
- row.relates_to,
- backfilled=True,
- )
- return super().process_replication_rows(stream_name, instance_name, token, rows)
-
- def _process_event_stream_row(self, token, row):
- data = row.data
-
- if row.type == EventsStreamEventRow.TypeId:
- self.invalidate_caches_for_event(
- token,
- data.event_id,
- data.room_id,
- data.type,
- data.state_key,
- data.redacts,
- data.relates_to,
- backfilled=False,
- )
- elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
-
- if data.type == EventTypes.Member:
- self.get_rooms_for_user_with_stream_ordering.invalidate(
- (data.state_key,)
- )
- else:
- raise Exception("Unknown events stream row type %s" % (row.type,))
-
- def invalidate_caches_for_event(
- self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
- self._invalidate_get_event_cache(event_id)
-
- self.get_latest_event_ids_in_room.invalidate((room_id,))
-
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
- if not backfilled:
- self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
- if redacts:
- self._invalidate_get_event_cache(redacts)
-
- if etype == EventTypes.Member:
- self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
- self.get_invited_rooms_for_local_user.invalidate((state_key,))
-
- if relates_to:
- self.get_relations_for_event.invalidate_many((relates_to,))
- self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
- self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 5d5816d7eb..6adb19463a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,19 +15,11 @@
# limitations under the License.
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
-from synapse.storage.database import Database
-from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
- self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id"
- )
- super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
-
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b48a6a3e91..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
import logging
from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
- db_query_to_update_function(self._update_function),
+ self._update_function,
+ )
+
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
)
- async def _update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
+
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
+ )
+
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
)
- return results
+ # we need to return a sorted list, so merge them together.
+ updates = list(heapq.merge(room_rows, global_rows))
+ return updates, to_token, limited
class GroupServerStream(Stream):
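The new `_update_function` above is the heart of this stream change: fetch global and room account-data rows separately, clamp `to_token` whenever either query may have been truncated by `limit`, then merge the two stream-ordered lists. A standalone sketch of that clamp-and-merge logic, with in-memory lists standing in for the database queries (the real code clamps sequentially, so the `min()` here is a simplification):

```python
import heapq

def merge_account_data(global_results, room_results, to_token, limit):
    """Rows are (stream_id, payload) tuples, each list sorted by stream_id."""
    limited = False

    # a query that returned `limit` rows may have been truncated, so clamp
    # to_token to the last stream id it actually covered
    if len(global_results) >= limit:
        to_token = global_results[-1][0]
        limited = True
    if len(room_results) >= limit:
        to_token = min(to_token, room_results[-1][0])
        limited = True

    global_rows = ((sid, row) for sid, row in global_results if sid <= to_token)
    room_rows = ((sid, row) for sid, row in room_results if sid <= to_token)

    # both inputs are sorted by stream id, so heapq.merge yields a sorted result
    return list(heapq.merge(global_rows, room_rows)), to_token, limited

globals_ = [(1, "g1"), (4, "g2"), (6, "g3")]
rooms = [(2, "r1"), (3, "r2"), (5, "r3")]
print(merge_account_data(globals_, rooms, to_token=6, limit=3))
# ([(1, 'g1'), (2, 'r1'), (3, 'r2'), (4, 'g2'), (5, 'r3')], 5, True)
```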
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index de7eca21f8..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request: SynapseRequest):
+ async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
- sso_url = self.get_sso_url(client_redirect_url)
+ sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
+ request: The client request to redirect.
client_redirect_url: the URL that we should redirect the
client to when everything is done
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._cas_handler.get_redirect_url(
{"redirectUrl": client_redirect_url}
).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
"""Implementation for /login/sso/redirect for the OIDC login flow."""
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
def __init__(self, hs):
self._oidc_handler = hs.get_oidc_handler()
- async def on_GET(self, request):
- args = request.args
- if b"redirectUrl" not in args:
- return 400, "Redirect URL not specified for SSO auth"
- client_redirect_url = args[b"redirectUrl"][0]
- await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url
+ )
def register_servlets(hs, http_server):
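The servlet refactor above turns `on_GET` into a template method: the base class parses the request and issues the redirect, while each SSO flow only supplies the target URL. A minimal sketch of that shape, with hypothetical class names in place of the real servlets:

```python
import asyncio

class BaseRedirectSketch:
    async def on_GET(self, client_redirect_url: bytes) -> str:
        # the real servlet validates args, then redirects and finishes the request
        sso_url = await self.get_sso_url(client_redirect_url)
        return "302 -> %s" % (sso_url,)

    async def get_sso_url(self, client_redirect_url: bytes) -> str:
        raise NotImplementedError()

class OidcRedirectSketch(BaseRedirectSketch):
    async def get_sso_url(self, client_redirect_url: bytes) -> str:
        # the real handler builds an authorization-endpoint URL and sets a
        # session cookie; this just returns something recognisable
        return "https://provider.example/authorize?redirect=%s" % (
            client_redirect_url.decode(),
        )

print(asyncio.run(OidcRedirectSketch().on_GET(b"https://client.example")))
```

Making `get_sso_url` async across all three subclasses is what lets the OIDC flow, which must await metadata discovery, share the base class with CAS and SAML.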
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 24dd3d3e96..7bca1326d5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
- self._saml_enabled = hs.config.saml2_enabled
- if self._saml_enabled:
- self._saml_handler = hs.get_saml_handler()
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
+ self._saml_enabled = hs.config.saml2_enabled
+ if self._saml_enabled:
+ self._saml_handler = hs.get_saml_handler()
+ self._oidc_enabled = hs.config.oidc_enabled
+ if self._oidc_enabled:
+ self._oidc_handler = hs.get_oidc_handler()
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
)
elif self._saml_enabled:
- client_redirect_url = ""
+ client_redirect_url = b""
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)
+ elif self._oidc_enabled:
+ client_redirect_url = b""
+ sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url, session
+ )
+
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 5df9dce79d..4b4763c701 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
- ChainedIdGenerator,
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
@@ -125,19 +124,6 @@ class DataStore(
self._clock = hs.get_clock()
self.database_engine = database.engine
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -164,9 +150,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- )
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..f9eef1b78e 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
import abc
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
- def get_all_updated_account_data(
- self, last_global_id, last_room_id, current_id, limit
- ):
- """Get all the client account_data that has changed on the server
+ async def get_updated_global_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+
Args:
- last_global_id(int): The position to fetch from for top level data
- last_room_id(int): The position to fetch from for per room data
- current_id(int): The position to fetch up to.
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
Returns:
- A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, and type string.
+ A list of tuples of stream_id int, user_id string,
+ and type string.
"""
- if last_room_id == current_id and last_global_id == current_id:
- return defer.succeed(([], []))
+ if last_id == current_id:
+ return []
- def get_updated_account_data_txn(txn):
+ def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_global_id, current_id, limit))
- global_results = txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return await self.db.runInteraction(
+ "get_updated_global_account_data", get_updated_global_account_data_txn
+ )
+
+ async def get_updated_room_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+ Args:
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
+ Returns:
+ A list of tuples of stream_id int, user_id string,
+ room_id string and type string.
+ """
+ if last_id == current_id:
+ return []
+
+ def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_room_id, current_id, limit))
- room_results = txn.fetchall()
- return global_results, room_results
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
- return self.db.runInteraction(
- "get_all_updated_account_data_txn", get_updated_account_data_txn
+ return await self.db.runInteraction(
+ "get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index efbc06c796..7a1fe8cdd2 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
- # We precompie a regex constructed from all the regexes that the AS's
+ # We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
regex.pattern
for service in services_cache
- for regex in service.get_exlusive_user_regexes()
+ for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 342a87a46b..eac5a4e55b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,8 +16,13 @@
import itertools
import logging
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, Tuple
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+ EventsStreamCurrentStateRow,
+ EventsStreamEventRow,
+)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "caches":
+ if stream_name == "events":
+ for row in rows:
+ self._process_event_stream_row(token, row)
+ elif stream_name == "backfill":
+ for row in rows:
+ self._invalidate_caches_for_event(
+ -token,
+ row.event_id,
+ row.room_id,
+ row.type,
+ row.state_key,
+ row.redacts,
+ row.relates_to,
+ backfilled=True,
+ )
+ elif stream_name == "caches":
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
@@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def _process_event_stream_row(self, token, row):
+ data = row.data
+
+ if row.type == EventsStreamEventRow.TypeId:
+ self._invalidate_caches_for_event(
+ token,
+ data.event_id,
+ data.room_id,
+ data.type,
+ data.state_key,
+ data.redacts,
+ data.relates_to,
+ backfilled=False,
+ )
+ elif row.type == EventsStreamCurrentStateRow.TypeId:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.data.room_id, token
+ )
+
+ if data.type == EventTypes.Member:
+ self.get_rooms_for_user_with_stream_ordering.invalidate(
+ (data.state_key,)
+ )
+ else:
+ raise Exception("Unknown events stream row type %s" % (row.type,))
+
+ def _invalidate_caches_for_event(
+ self,
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ state_key,
+ redacts,
+ relates_to,
+ backfilled,
+ ):
+ self._invalidate_get_event_cache(event_id)
+
+ self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+ if not backfilled:
+ self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+ if redacts:
+ self._invalidate_get_event_cache(redacts)
+
+ if etype == EventTypes.Member:
+ self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+ if relates_to:
+ self.get_relations_for_event.invalidate_many((relates_to,))
+ self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
+
+ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ cache_func = getattr(self, cache_name, None)
+ if not cache_func:
+ return
+
+ cache_func.invalidate(keys)
+ await self.db.runInteraction(
+ "invalidate_cache_and_stream",
+ self._send_invalidation_to_replication,
+ cache_func.__name__,
+ keys,
+ )
+
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
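For orientation, the dispatch above keys off the stream name, and backfill stream positions grow downwards, so tokens are negated before use as stream orderings. A toy sketch of that convention (handler and row shapes invented):

```python
# Toy dispatcher; real rows carry event_id, room_id, type, state_key, etc.
def process_rows(stream_name: str, token: int, rows: list) -> None:
    if stream_name == "events":
        for row in rows:
            invalidate_for_event(token, row)
    elif stream_name == "backfill":
        for row in rows:
            # backfill tokens count downwards, so negate to get a stream ordering
            invalidate_for_event(-token, row)
    elif stream_name == "caches":
        print("explicit cache invalidation at", token)

def invalidate_for_event(stream_ordering: int, row) -> None:
    print("invalidate caches at", stream_ordering, "for", row)
```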
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 970c31bd05..9130b74eb5 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
@@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+ if hs.config.worker_app is None:
+ # We are the process in charge of generating stream ids for events,
+ # so instantiate ID generators based on the database
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")],
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ # Another process is in charge of persisting events and generating
+ # stream IDs: rely on the replication streams to let us know which
+ # IDs we can process.
+ self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
+
self._get_event_cache = Cache(
"*getEvent*",
keylen=3,
@@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == "events":
+ self._stream_id_gen.advance(token)
+ elif stream_name == "backfill":
+ self._backfill_id_gen.advance(-token)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
def get_received_ts(self, event_id):
"""Get received_ts (when it was persisted) for the event.
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 0963e6c250..fb1361f1c1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group",
)
- def get_rooms_in_group(self, group_id, include_private=False):
+ def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+ """Retrieve the rooms that belong to a given group. Does not return rooms that
+ lack members.
+
+ Args:
+ group_id: The ID of the group to query for rooms
+ include_private: Whether to return private rooms in results
+
+ Returns:
+ Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+ form of:
+
+ {
+ "room_id": "!a_room_id:example.com", # The ID of the room
+ "is_public": False # Whether this is a public room or not
+ }
+ """
# TODO: Pagination
- keyvalues = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
+ def _get_rooms_in_group_txn(txn):
+ sql = """
+ SELECT room_id, is_public FROM group_rooms
+ WHERE group_id = ?
+ AND room_id IN (
+ SELECT group_rooms.room_id FROM group_rooms
+ LEFT JOIN room_stats_current ON
+ group_rooms.room_id = room_stats_current.room_id
+ AND joined_members > 0
+ AND local_users_in_room > 0
+ LEFT JOIN rooms ON
+ group_rooms.room_id = rooms.room_id
+ AND (room_version <> '') = ?
+ )
+ """
+ args = [group_id, False]
- return self.db.simple_select_list(
- table="group_rooms",
- keyvalues=keyvalues,
- retcols=("room_id", "is_public"),
- desc="get_rooms_in_group",
- )
+ if not include_private:
+ sql += " AND is_public = ?"
+ args += [True]
+
+ txn.execute(sql, args)
+
+ return [
+ {"room_id": room_id, "is_public": is_public}
+ for room_id, is_public in txn
+ ]
- def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+ return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+ def get_rooms_for_summary_by_category(
+ self, group_id: str, include_private: bool = False,
+ ):
"""Get the rooms and categories that should be included in a summary request
- Returns ([rooms], [categories])
+ Args:
+ group_id: The ID of the group to query the summary for
+ include_private: Whether to return private rooms in results
+
+ Returns:
+ Deferred[Tuple[List, Dict]]: A tuple containing:
+
+ * A list of dictionaries with the keys:
+ * "room_id": str, the room ID
+ * "is_public": bool, whether the room is public
+ * "category_id": str|None, the category ID if set, else None
+ * "order": int, the sort order of rooms
+
+ * A dictionary with the key:
+ * category_id (str): a dictionary with the keys:
+ * "is_public": bool, whether the category is public
+ * "profile": str, the category profile
+ * "order": int, the sort order of rooms in this category
"""
def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
SELECT room_id, is_public, category_id, room_order
FROM group_summary_rooms
WHERE group_id = ?
+ AND room_id IN (
+ SELECT group_rooms.room_id FROM group_rooms
+ LEFT JOIN room_stats_current ON
+ group_rooms.room_id = room_stats_current.room_id
+ AND joined_members > 0
+ AND local_users_in_room > 0
+ LEFT JOIN rooms ON
+ group_rooms.room_id = rooms.room_id
+ AND (room_version <> '') = ?
+ )
"""
if not include_private:
sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
+ txn.execute(sql, (group_id, False, True))
else:
- txn.execute(sql, (group_id,))
+ txn.execute(sql, (group_id, False))
rooms = [
{
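One subtlety in the new queries: `(room_version <> '') = ?` compares a boolean expression against a bound parameter, so binding `False` matches exactly the rows whose version string is empty (i.e. invalid). A quick self-contained check of just that comparison, using SQLite for brevity:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE rooms (room_id TEXT, room_version TEXT)")
conn.executemany(
    "INSERT INTO rooms VALUES (?, ?)",
    [("!ok:example.com", "5"), ("!bad:example.com", "")],
)
# Binding False makes the predicate true only where room_version is empty.
rows = conn.execute(
    "SELECT room_id FROM rooms WHERE (room_version <> '') = ?", (False,)
).fetchall()
print(rows)  # [('!bad:example.com',)]
```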
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2b52cf9c1a..bfc9369f0b 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- values={
+ updatevalues={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b3faafa0a4..ef8f40959f 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,19 +16,23 @@
import abc
import logging
+from typing import Union
from canonicaljson import json
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
from synapse.storage.database import Database
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -64,6 +68,7 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
+ EventsWorkerStore,
SQLBaseStore,
):
"""This is an abstract base class where subclasses must implement
@@ -77,6 +82,15 @@ class PushRulesWorkerStore(
def __init__(self, database: Database, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+ if hs.config.worker.worker_app is None:
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ else:
+ self._push_rules_stream_id_gen = SlavedIdTracker(
+ db_conn, "push_rules_stream", "stream_id"
+ )
+
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
"push_rules_stream",
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index ee75b92344..13f49d8060 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
)
-class SearchBackgroundUpdateStore(SQLBaseStore):
+class SearchWorkerStore(SQLBaseStore):
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if not self.hs.config.enable_search:
+ return
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = (
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ entry.value,
+ entry.stream_ordering,
+ entry.origin_server_ts,
+ )
+ for entry in entries
+ )
+
+ txn.executemany(sql, args)
+
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = (
+ (entry.event_id, entry.room_id, entry.key, entry.value)
+ for entry in entries
+ )
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
return num_rows
- def store_search_entries_txn(self, txn, entries):
- """Add entries to the search table
-
- Args:
- txn (cursor):
- entries (iterable[SearchEntry]):
- entries to be added to the table
- """
- if not self.hs.config.enable_search:
- return
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
-
- args = (
- (
- entry.event_id,
- entry.room_id,
- entry.key,
- entry.value,
- entry.stream_ordering,
- entry.origin_server_ts,
- )
- for entry in entries
- )
-
- txn.executemany(sql, args)
-
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- args = (
- (entry.event_id, entry.room_id, entry.key, entry.value)
- for entry in entries
- )
-
- txn.executemany(sql, args)
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
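Both branches of `store_search_entries_txn` feed `executemany` a generator rather than a materialised list, so parameter tuples are built lazily. A self-contained illustration of the same pattern (SQLite, invented schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_search (event_id TEXT, room_id TEXT, key TEXT, value TEXT)"
)

entries = [
    ("$e1", "!r:x", "content.body", "hello"),
    ("$e2", "!r:x", "content.body", "world"),
]
# executemany accepts any iterable of parameter tuples; a generator avoids
# building an intermediate list when there are many entries.
args = ((event_id, room_id, key, value) for (event_id, room_id, key, value) in entries)
conn.executemany(
    "INSERT INTO event_search (event_id, room_id, key, value) VALUES (?,?,?,?)", args
)
print(conn.execute("SELECT COUNT(*) FROM event_search").fetchone())  # (2,)
```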
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 86d04ea9ac..f89ce0bed2 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -166,6 +166,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
+ self._table = table
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
@@ -204,6 +205,16 @@ class ChainedIdGenerator(object):
return self._current_max, self.chained_generator.get_current_token()
+ def advance(self, token: int):
+ """Stub implementation for advancing the token when receiving updates
+ over replication; raises an exception as this instance should be the
+ only source of updates.
+ """
+
+        raise Exception(
+            "Attempted to advance token on source for table %r" % (self._table,)
+        )
+
class MultiWriterIdGenerator:
"""An ID generator that tracks a stream that can have multiple writers.
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 61963aa90d..1bb25ab684 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
- req = Mock(spec=["addCookie", "redirect", "finish"])
- yield defer.ensureDeferred(
+ req = Mock(spec=["addCookie"])
+ url = yield defer.ensureDeferred(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
- url = req.redirect.call_args[0][0]
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
nonce = "nonce"
client_redirect_url = "http://client/redirect"
session = self.handler._generate_oidc_session_token(
- state=state, nonce=nonce, client_redirect_url=client_redirect_url,
+ state=state,
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
)
request.getCookie.return_value = session
@@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Mismatching session
session = self.handler._generate_oidc_session_token(
- state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+ state="state",
+ nonce="nonce",
+ client_redirect_url="http://client/redirect",
+ ui_auth_session_id=None,
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0fee8a71c4..1a88c7fb80 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -17,11 +17,12 @@ from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
-from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
+from tests.server import FakeTransport
+
from ._base import BaseSlavedStoreTestCase
USER_ID = "@feeling:test"
@@ -240,6 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# limit the replication rate
repl_transport = self._server_transport
+ assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
@@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.message",
key=None,
internal={},
- state=None,
depth=None,
prev_events=[],
auth_events=[],
@@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
self.event_id += 1
-
- if state is not None:
- state_ids = {key: e.event_id for key, e in state.items()}
- context = EventContext.with_state(
- state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
- )
- else:
- state_handler = self.hs.get_state_handler()
- context = self.get_success(state_handler.compute_event_context(event))
+ state_handler = self.hs.get_state_handler()
+ context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
new file mode 100644
index 0000000000..6a5116dd2a
--- /dev/null
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.tcp.streams._base import (
+ _STREAM_UPDATE_TARGET_ROW_COUNT,
+ AccountDataStream,
+)
+
+from tests.replication._base import BaseStreamTestCase
+
+
+class AccountDataStreamTestCase(BaseStreamTestCase):
+ def test_update_function_room_account_data_limit(self):
+ """Test replication with many room account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", update, {})
+ )
+ updates.append(update)
+
+ # also one global update
+ self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertEqual(row.room_id, "test_room")
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.global")
+ self.assertIsNone(row.room_id)
+
+ self.assertEqual([], received_rows)
+
+ def test_update_function_global_account_data_limit(self):
+ """Test replication with many global account data updates
+ """
+ store = self.hs.get_datastore()
+
+ # generate lots of account data updates
+ updates = []
+ for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
+ update = "m.test_type.%i" % (i,)
+ self.get_success(store.add_account_data_for_user("test_user", update, {}))
+ updates.append(update)
+
+ # also one per-room update
+ self.get_success(
+ store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
+ )
+
+ # tell the notifier to catch up to avoid duplicate rows.
+ # workaround for https://github.com/matrix-org/synapse/issues/7360
+ # FIXME remove this when the above is fixed
+ self.replicate()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # we should have received all the expected rows in the right order
+ received_rows = self.test_handler.received_rdata_rows
+
+ for t in updates:
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertEqual(stream_name, AccountDataStream.NAME)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, t)
+ self.assertIsNone(row.room_id)
+
+ (stream_name, token, row) = received_rows.pop(0)
+ self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
+ self.assertEqual(row.data_type, "m.per_room")
+ self.assertEqual(row.room_id, "test_room")
+
+ self.assertEqual([], received_rows)
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index 7ddfd0a733..60c10a441a 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata(self):
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
- self.assertIsInstance(cmd, RdataCommand)
+ assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "events")
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
@@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata_batch(self):
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
- self.assertIsInstance(cmd, RdataCommand)
+ assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "presence")
self.assertEqual(cmd.instance_name, "master")
self.assertIsNone(cmd.token)
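The switch from `assertIsInstance` to a bare `assert isinstance(...)` is deliberate: mypy narrows a variable's type after a plain assert but does not understand unittest's helper. A small demonstration:

```python
from typing import Union

class RdataCommand:
    stream_name = "events"

class PingCommand:
    pass

def check(cmd: Union[RdataCommand, PingCommand]) -> str:
    assert isinstance(cmd, RdataCommand)  # mypy narrows cmd to RdataCommand here
    return cmd.stream_name  # would be a type error without the narrowing assert
```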
diff --git a/tox.ini b/tox.ini
index 203c648008..3bb4d45e2a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -180,6 +180,7 @@ commands = mypy \
synapse/api \
synapse/appservice \
synapse/config \
+ synapse/event_auth.py \
synapse/events/spamcheck.py \
synapse/federation \
synapse/handlers/auth.py \
@@ -187,6 +188,8 @@ commands = mypy \
synapse/handlers/directory.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
+ synapse/handlers/room_member.py \
+ synapse/handlers/room_member_worker.py \
synapse/handlers/saml_handler.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
@@ -204,7 +207,7 @@ commands = mypy \
synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
- tests/replication/tcp/streams \
+ tests/replication \
tests/test_utils \
tests/rest/client/v2_alpha/test_auth.py \
tests/util/test_stream_change_cache.py
|