diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index f924116819..7597fbc864 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -56,6 +56,7 @@ class OIDCConfig(Config):
self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
self.oidc_jwks_uri = oidc_config.get("jwks_uri")
self.oidc_skip_verification = oidc_config.get("skip_verification", False)
+ self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
ump_config = oidc_config.get("user_mapping_provider", {})
@@ -159,6 +160,14 @@ class OIDCConfig(Config):
#
#skip_verification: true
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ #
+ # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
+ # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
+ #
+ #user_profile_method: "userinfo_endpoint"
+
# Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
# of failing. This could be used if switching from password logins to OIDC. Defaults to false.
#
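For context on the new option above, a hypothetical validation sketch (not part of this patch; the function name and `ValueError` message are illustrative) showing how a typo in `user_profile_method` could fail fast at config-parse time instead of silently behaving like "auto":

```python
# Hypothetical sketch, not in this patch: reject unknown values up front.
def parse_user_profile_method(oidc_config: dict) -> str:
    value = oidc_config.get("user_profile_method", "auto")
    if value not in ("auto", "userinfo_endpoint"):
        raise ValueError(
            "user_profile_method must be 'auto' or 'userinfo_endpoint', got %r"
            % (value,)
        )
    return value

assert parse_user_profile_method({}) == "auto"
assert parse_user_profile_method({"user_profile_method": "userinfo_endpoint"}) == "userinfo_endpoint"
```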
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 9ddb8b546b..ad37b93c02 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,7 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
-from typing import List
+from typing import List, Optional
from unpaddedbase64 import encode_base64
@@ -177,8 +177,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
- self.tls_certificate = None
- self.tls_private_key = None
+ self.tls_certificate = None # type: Optional[crypto.X509]
+ self.tls_private_key = None # type: Optional[crypto.PKey]
def is_disk_cert_valid(self, allow_self_signed=True):
"""
@@ -226,12 +226,12 @@ class TlsConfig(Config):
days_remaining = (expires_on - now).days
return days_remaining
- def read_certificate_from_disk(self, require_cert_and_key):
+ def read_certificate_from_disk(self, require_cert_and_key: bool):
"""
Read the certificates and private key from disk.
Args:
- require_cert_and_key (bool): set to True to throw an error if the certificate
+ require_cert_and_key: set to True to throw an error if the certificate
and key file are not given
"""
if require_cert_and_key:
@@ -479,13 +479,13 @@ class TlsConfig(Config):
}
)
- def read_tls_certificate(self):
+ def read_tls_certificate(self) -> crypto.X509:
"""Reads the TLS certificate from the configured file, and returns it
Also checks if it is self-signed, and warns if so
Returns:
- OpenSSL.crypto.X509: the certificate
+ The certificate
"""
cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path)
@@ -504,11 +504,11 @@ class TlsConfig(Config):
return cert
- def read_tls_private_key(self):
+ def read_tls_private_key(self) -> crypto.PKey:
"""Reads the TLS private key from the configured file, and returns it
Returns:
- OpenSSL.crypto.PKey: the private key
+ The private key
"""
private_key_path = self.tls_private_key_file
logger.info("Loading TLS key from %s", private_key_path)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 24329dd0e3..02f11e1209 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
Callable,
Dict,
List,
- Match,
Optional,
Tuple,
Union,
@@ -825,14 +824,14 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
return False
-def _acl_entry_matches(server_name: str, acl_entry: str) -> Match:
+def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
- return regex.match(server_name)
+ return bool(regex.match(server_name))
class FederationHandlerRegistry:
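The annotation fix above matters because `re.match` returns `Optional[Match]` and the non-str branch returns `False`, so the old `-> Match` return type was wrong on both paths. A self-contained sketch of the matching (the glob translation here is an assumption standing in for Synapse's `glob_to_regex`):

```python
import re

def glob_to_regex_sketch(glob: str) -> "re.Pattern":
    # Assumed glob semantics: "*" matches any run, "?" a single character.
    pattern = re.escape(glob).replace(r"\*", ".*").replace(r"\?", ".")
    return re.compile(r"\A" + pattern + r"\Z", re.IGNORECASE)

def acl_entry_matches_sketch(server_name: str, acl_entry: object) -> bool:
    if not isinstance(acl_entry, str):
        return False  # non-string entries are ignored, as in the patch
    # re.match returns Optional[Match]; bool() yields the declared bool.
    return bool(glob_to_regex_sketch(acl_entry).match(server_name))

assert acl_entry_matches_sketch("matrix.example.com", "*.example.com")
assert not acl_entry_matches_sketch("matrix.example.com", 42)
```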
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 62aa9a2da8..6f15c68240 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -383,7 +383,7 @@ class DirectoryHandler(BaseHandler):
"""
creator = await self.store.get_room_alias_creator(alias.to_string())
- if creator is not None and creator == user_id:
+ if creator == user_id:
return True
# Resolve the alias to the corresponding room.
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 19cd652675..05ac86e697 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -96,6 +96,7 @@ class OidcHandler:
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._user_profile_method = hs.config.oidc_user_profile_method # type: str
self._client_auth = ClientAuth(
hs.config.oidc_client_id,
hs.config.oidc_client_secret,
@@ -196,11 +197,11 @@ class OidcHandler:
% (m["response_types_supported"],)
)
- # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ # Ensure there's a userinfo endpoint to fetch from if it is required.
if self._uses_userinfo:
if m.get("userinfo_endpoint") is None:
raise ValueError(
- 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ 'provider has no "userinfo_endpoint", even though it is required'
)
else:
# If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -220,8 +221,10 @@ class OidcHandler:
``access_token`` with the ``userinfo_endpoint``.
"""
- # Maybe that should be user-configurable and not inferred?
- return "openid" not in self._scopes
+ return (
+ "openid" not in self._scopes
+ or self._user_profile_method == "userinfo_endpoint"
+ )
async def load_metadata(self) -> OpenIDProviderMetadata:
"""Load and validate the provider metadata.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d5f7c78edf..f1a6699cd4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -962,8 +962,6 @@ class RoomCreationHandler(BaseHandler):
try:
random_string = stringutils.random_string(18)
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
- if isinstance(gen_room_id, bytes):
- gen_room_id = gen_room_id.decode("utf-8")
await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8feba8c90a..5ec36f591d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -642,7 +642,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def send_membership_event(
self,
- requester: Requester,
+ requester: Optional[Requester],
event: EventBase,
context: EventContext,
ratelimit: bool = True,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bfe2583002..260ec19b41 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -87,7 +87,7 @@ class SyncConfig:
class TimelineBatch:
prev_batch = attr.ib(type=StreamToken)
events = attr.ib(type=List[EventBase])
- limited = attr.ib(bool)
+ limited = attr.ib(type=bool)
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..09ed74f6ce 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -257,7 +257,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
@@ -406,7 +406,7 @@ class JsonResource(DirectServeJsonResource):
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
- callback_return = raw_callback_return
+ callback_return = raw_callback_return # type: ignore
return callback_return
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 144506c8f2..0fc2ea609e 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import os.path
import sys
@@ -89,14 +88,7 @@ class LogContextObserver:
context = current_context()
# Copy the context information to the log event.
- if context is not None:
- context.copy_to_twisted_log_entry(event)
- else:
- # If there's no logging context, not even the root one, we might be
- # starting up or it might be from non-Synapse code. Log it as if it
- # came from the root logger.
- event["request"] = None
- event["scope"] = None
+ context.copy_to_twisted_log_entry(event)
self.observer(event)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 709ace01e5..3a68ce636f 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
import logging
import re
-from typing import Any, Dict, List, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -181,7 +181,7 @@ class PushRuleEvaluatorForEvent:
return r.search(body)
- def _get_value(self, dotted_key: str) -> str:
+ def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b323841f73..e92da7b263 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -251,10 +251,9 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
- import txredisapi
-
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
+ lazyConnection,
)
logger.info(
@@ -271,7 +270,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = txredisapi.lazyConnection(
+ outbound_redis_connection = lazyConnection(
+ reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0b0d204e64..a509e599c2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,10 +51,11 @@ import fcntl
import logging
import struct
from inspect import isawaitable
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
+from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
@@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
- self.time_we_closed = None # When we requested the connection be closed
+ # When we requested the connection be closed
+ self.time_we_closed = None # type: Optional[int]
- self.received_ping = False # Have we reecived a ping from the other side
+ self.received_ping = False # Have we received a ping from the other side
self.state = ConnectionStates.CONNECTING
@@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
- self._send_ping_loop = None
+ self._send_ping_loop = None # type: Optional[task.LoopingCall]
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index f225e533de..de19705c1f 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
import txredisapi
@@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password
return p
+
+
+def lazyConnection(
+ reactor,
+ host: str = "localhost",
+ port: int = 6379,
+ dbid: Optional[int] = None,
+ reconnect: bool = True,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ connectTimeout: Optional[int] = None,
+ replyTimeout: Optional[int] = None,
+ convertNumbers: bool = True,
+) -> txredisapi.RedisProtocol:
+ """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
+ reactor.
+ """
+
+ isLazy = True
+ poolsize = 1
+
+ uuid = "%s:%d" % (host, port)
+ factory = txredisapi.RedisFactory(
+ uuid,
+ dbid,
+ poolsize,
+ isLazy,
+ txredisapi.ConnectionHandler,
+ charset,
+ password,
+ replyTimeout,
+ convertNumbers,
+ )
+ factory.continueTrying = reconnect
+ for x in range(poolsize):
+ reactor.connectTCP(host, port, factory, connectTimeout)
+
+ return factory.handler
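A minimal usage sketch for the wrapper (the global-reactor import and the `ping()` call are illustrative; in Synapse the point is to pass `hs.get_reactor()` instead, as the handler.py hunk above does):

```python
# Illustrative usage; txredisapi queues commands issued before the TCP
# connection is up and flushes them once connected.
from twisted.internet import reactor

conn = lazyConnection(
    reactor=reactor,  # Synapse passes hs.get_reactor() here instead
    host="localhost",
    port=6379,
    password=None,
)
d = conn.ping()  # returns a Deferred, fired once Redis replies
```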
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 31082bb16a..5b0900aa3c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -738,7 +738,7 @@ def _make_state_cache_entry(
# failing that, look for the closest match.
prev_group = None
- delta_ids = None
+ delta_ids = None # type: Optional[StateMap[str]]
for old_group, old_state in state_groups_ids.items():
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..4bb2b9c28c 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = encode_json(
+ pruned_json = frozendict_json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(
+ pruned_json = frozendict_json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..78e645592f 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -52,16 +52,6 @@ event_counter = Counter(
)
-def encode_json(json_object):
- """
- Encode a Python object as JSON and return it in a Unicode string.
- """
- out = frozendict_json_encoder.encode(json_object)
- if isinstance(out, bytes):
- out = out.decode("utf8")
- return out
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -743,7 +733,9 @@ class PersistEventsStore:
logger.exception("")
raise
- metadata_json = encode_json(event.internal_metadata.get_dict())
+ metadata_json = frozendict_json_encoder.encode(
+ event.internal_metadata.get_dict()
+ )
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -797,10 +789,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
- "internal_metadata": encode_json(
+ "internal_metadata": frozendict_json_encoder.encode(
event.internal_metadata.get_dict()
),
- "json": encode_json(event_dict(event)),
+ "json": frozendict_json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts
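For context on the replacement: `json.JSONEncoder.encode` always returns `str` on Python 3, which is why the bytes branch in the removed `encode_json` helper was dead code. A minimal sketch (assuming the `frozendict` package) of an encoder like `synapse.util.frozenutils.frozendict_json_encoder`:

```python
import json

from frozendict import frozendict

def _handle_frozendict(obj):
    # Serialise immutable frozendicts as plain dicts.
    if isinstance(obj, frozendict):
        return dict(obj)
    raise TypeError("not JSON serializable: %r" % (obj,))

frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)

out = frozendict_json_encoder.encode({"content": frozendict({"body": "hi"})})
assert isinstance(out, str)  # never bytes, so no .decode("utf8") needed
```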
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..1d27439536 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -546,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> Tuple[int, int, str]:
+ ) -> Optional[Tuple[int, int, str]]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 48efbb5067..c92cd4a6ba 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -421,7 +421,7 @@ class MultiWriterIdGenerator:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
- new_cur = None
+ new_cur = None # type: Optional[int]
if self._unfinished_ids:
# If there are unfinished IDs then the new position will be the
|