From c951fbedcb81895c199c1f4cfe2251d6c3a7b5f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 13:09:41 -0500 Subject: MSC3873: Escape keys when flattening dicts. (#15004) This disambiguates keys which attempt to match fields with a dot in them (e.g. m.relates_to). Disabled by default behind an experimental configuration flag. --- synapse/config/experimental.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'synapse/config') diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 53c0682dfd..5e3a889081 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -169,6 +169,11 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + # MSC3873: Disambiguate event_match keys. + self.msc3783_escape_event_match_key = experimental.get( + "msc3783_escape_event_match_key", False + ) + # MSC3952: Intentional mentions self.msc3952_intentional_mentions = experimental.get( "msc3952_intentional_mentions", False -- cgit 1.5.1 From 14be78d492fc31e743e9e5855ddb8b4c9520985a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 10 Feb 2023 12:37:07 -0500 Subject: Support for MSC3758: exact_event_match push condition (#14964) This specifies to search for an exact value match, instead of string globbing. It only works across non-compound JSON values (null, boolean, integer, and strings). --- changelog.d/14964.feature | 1 + rust/benches/evaluator.rs | 65 +++++++++++--- rust/src/push/evaluator.rs | 69 +++++++++++---- rust/src/push/mod.rs | 83 +++++++++++++++++ stubs/synapse/synapse_rust/push.pyi | 7 +- synapse/config/experimental.py | 5 ++ synapse/push/bulk_push_rule_evaluator.py | 18 ++-- synapse/types/__init__.py | 2 + tests/push/test_push_rule_evaluator.py | 147 ++++++++++++++++++++++++++++++- 9 files changed, 356 insertions(+), 41 deletions(-) create mode 100644 changelog.d/14964.feature (limited to 'synapse/config') diff --git a/changelog.d/14964.feature b/changelog.d/14964.feature new file mode 100644 index 0000000000..13c0bc193b --- /dev/null +++ b/changelog.d/14964.feature @@ -0,0 +1 @@ +Implement the experimental `exact_event_match` push rule condition from [MSC3758](https://github.com/matrix-org/matrix-spec-proposals/pull/3758). diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 35f7a50bce..229553ebf8 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -16,6 +16,7 @@ use std::collections::BTreeSet; use synapse::push::{ evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, + SimpleJsonValue, }; use test::Bencher; @@ -24,9 +25,18 @@ extern crate test; #[bench] fn bench_match_exact(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -43,6 +53,7 @@ fn bench_match_exact(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -63,9 +74,18 @@ fn bench_match_exact(b: &mut Bencher) { #[bench] fn bench_match_word(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -82,6 +102,7 @@ fn bench_match_word(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -102,9 +123,18 @@ fn bench_match_word(b: &mut Bencher) { #[bench] fn bench_match_word_miss(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -121,6 +151,7 @@ fn bench_match_word_miss(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); @@ -141,9 +172,18 @@ fn bench_match_word_miss(b: &mut Bencher) { #[bench] fn bench_eval_message(b: &mut Bencher) { let flattened_keys = [ - ("type".to_string(), "m.text".to_string()), - ("room_id".to_string(), "!room:server".to_string()), - ("content.body".to_string(), "test message".to_string()), + ( + "type".to_string(), + SimpleJsonValue::Str("m.text".to_string()), + ), + ( + "room_id".to_string(), + SimpleJsonValue::Str("!room:server".to_string()), + ), + ( + "content.body".to_string(), + SimpleJsonValue::Str("test message".to_string()), + ), ] .into_iter() .collect(); @@ -160,6 +200,7 @@ fn bench_eval_message(b: &mut Bencher) { true, vec![], false, + false, ) .unwrap(); diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index ec7a8c4453..dd6b4343ec 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -22,8 +22,8 @@ use regex::Regex; use super::{ utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType}, - Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition, - RelatedEventMatchCondition, + Action, Condition, EventMatchCondition, ExactEventMatchCondition, FilteredPushRules, + KnownCondition, RelatedEventMatchCondition, SimpleJsonValue, }; lazy_static! { @@ -61,9 +61,9 @@ impl RoomVersionFeatures { /// Allows running a set of push rules against a particular event. #[pyclass] pub struct PushRuleEvaluator { - /// A mapping of "flattened" keys to string values in the event, e.g. + /// A mapping of "flattened" keys to simple JSON values in the event, e.g. /// includes things like "type" and "content.msgtype". - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, /// The "content.body", if any. body: String, @@ -87,7 +87,7 @@ pub struct PushRuleEvaluator { /// The related events, indexed by relation type. Flattened in the same manner as /// `flattened_keys`. - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, @@ -98,6 +98,9 @@ pub struct PushRuleEvaluator { /// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, + + /// If MSC3758 (exact_event_match push rule condition) is enabled. + msc3758_exact_event_match: bool, } #[pymethods] @@ -106,22 +109,23 @@ impl PushRuleEvaluator { #[allow(clippy::too_many_arguments)] #[new] pub fn py_new( - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, related_event_match_enabled: bool, room_version_feature_flags: Vec, msc3931_enabled: bool, + msc3758_exact_event_match: bool, ) -> Result { - let body = flattened_keys - .get("content.body") - .cloned() - .unwrap_or_default(); + let body = match flattened_keys.get("content.body") { + Some(SimpleJsonValue::Str(s)) => s.clone(), + _ => String::new(), + }; Ok(PushRuleEvaluator { flattened_keys, @@ -136,6 +140,7 @@ impl PushRuleEvaluator { related_event_match_enabled, room_version_feature_flags, msc3931_enabled, + msc3758_exact_event_match, }) } @@ -252,6 +257,9 @@ impl PushRuleEvaluator { KnownCondition::EventMatch(event_match) => { self.match_event_match(event_match, user_id)? } + KnownCondition::ExactEventMatch(exact_event_match) => { + self.match_exact_event_match(exact_event_match)? + } KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } @@ -337,7 +345,9 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(haystack) = self.flattened_keys.get(&*event_match.key) { + let haystack = if let Some(SimpleJsonValue::Str(haystack)) = + self.flattened_keys.get(&*event_match.key) + { haystack } else { return Ok(false); @@ -355,6 +365,27 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `exact_event_match` condition. (MSC3758) + fn match_exact_event_match( + &self, + exact_event_match: &ExactEventMatchCondition, + ) -> Result { + // First check if the feature is enabled. + if !self.msc3758_exact_event_match { + return Ok(false); + } + + let value = &exact_event_match.value; + + let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) { + haystack + } else { + return Ok(false); + }; + + Ok(haystack == &**value) + } + /// Evaluates a `related_event_match` condition. (MSC3664) fn match_related_event_match( &self, @@ -410,7 +441,7 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(haystack) = event.get(&**key) { + let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) { haystack } else { return Ok(false); @@ -455,7 +486,10 @@ impl PushRuleEvaluator { #[test] fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); - flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + flattened_keys.insert( + "content.body".to_string(), + SimpleJsonValue::Str("foo bar bob hello".to_string()), + ); let evaluator = PushRuleEvaluator::py_new( flattened_keys, false, @@ -468,6 +502,7 @@ fn push_rule_evaluator() { true, vec![], true, + true, ) .unwrap(); @@ -482,7 +517,10 @@ fn test_requires_room_version_supports_condition() { use crate::push::{PushRule, PushRules}; let mut flattened_keys = BTreeMap::new(); - flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); + flattened_keys.insert( + "content.body".to_string(), + SimpleJsonValue::Str("foo bar bob hello".to_string()), + ); let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, @@ -496,6 +534,7 @@ fn test_requires_room_version_supports_condition() { false, flags, true, + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 3c4f876cab..79e519fe11 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -56,7 +56,9 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{Context, Error}; use log::warn; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; +use pyo3::types::{PyBool, PyLong, PyString}; use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; @@ -248,6 +250,36 @@ impl<'de> Deserialize<'de> for Action { } } +/// A simple JSON values (string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum SimpleJsonValue { + Str(String), + Int(i64), + Bool(bool), + Null, +} + +impl<'source> FromPyObject<'source> for SimpleJsonValue { + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok(s) = ::try_from(ob) { + Ok(SimpleJsonValue::Str(s.to_string())) + // A bool *is* an int, ensure we try bool first. + } else if let Ok(b) = ::try_from(ob) { + Ok(SimpleJsonValue::Bool(b.extract()?)) + } else if let Ok(i) = ::try_from(ob) { + Ok(SimpleJsonValue::Int(i.extract()?)) + } else if ob.is_none() { + Ok(SimpleJsonValue::Null) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to SimpleJsonValue", + ob.get_type().name()? + ))) + } + } +} + /// A condition used in push rules to match against an event. /// /// We need this split as `serde` doesn't give us the ability to have a @@ -267,6 +299,8 @@ pub enum Condition { #[serde(tag = "kind")] pub enum KnownCondition { EventMatch(EventMatchCondition), + #[serde(rename = "com.beeper.msc3758.exact_event_match")] + ExactEventMatch(ExactEventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] @@ -309,6 +343,13 @@ pub struct EventMatchCondition { pub pattern_type: Option>, } +/// The body of a [`Condition::ExactEventMatch`] +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ExactEventMatchCondition { + pub key: Cow<'static, str>, + pub value: Cow<'static, SimpleJsonValue>, +} + /// The body of a [`Condition::RelatedEventMatch`] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct RelatedEventMatchCondition { @@ -542,6 +583,48 @@ fn test_deserialize_unstable_msc3931_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3758_condition() { + // A string condition should work. + let json = + r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":"foo"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); + + // A boolean condition should work. + let json = + r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":true}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); + + // An integer condition should work. + let json = r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":1}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); + + // A null condition should work + let json = + r#"{"kind":"com.beeper.msc3758.exact_event_match","key":"content.value","value":null}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::ExactEventMatch(_)) + )); +} + #[test] fn test_deserialize_unstable_msc3952_user_condition() { let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 754acab2f9..328f681a29 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -14,7 +14,7 @@ from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union -from synapse.types import JsonDict +from synapse.types import JsonDict, SimpleJsonValue class PushRule: @property @@ -56,17 +56,18 @@ def get_base_rule_ids() -> Collection[str]: ... class PushRuleEvaluator: def __init__( self, - flattened_keys: Mapping[str, str], + flattened_keys: Mapping[str, SimpleJsonValue], has_mentions: bool, user_mentions: Set[str], room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - related_events_flattened: Mapping[str, Mapping[str, str]], + related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]], related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, + msc3758_exact_event_match: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 5e3a889081..6ac2f0c10d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -169,6 +169,11 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + # MSC3758: exact_event_match push rule condition + self.msc3758_exact_event_match = experimental.get( + "msc3758_exact_event_match", False + ) + # MSC3873: Disambiguate event_match keys. self.msc3783_escape_event_match_key = experimental.get( "msc3783_escape_event_match_key", False diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 39d2f88f03..8568aca528 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -43,6 +43,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types import SimpleJsonValue from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func @@ -256,13 +257,15 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + async def _related_events( + self, event: EventBase + ) -> Dict[str, Dict[str, SimpleJsonValue]]: """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: Mapping of relation type to flattened events. """ - related_events: Dict[str, Dict[str, str]] = {} + related_events: Dict[str, Dict[str, SimpleJsonValue]] = {} if self._related_event_match_enabled: related_event_id = event.content.get("m.relates_to", {}).get("event_id") relation_type = event.content.get("m.relates_to", {}).get("rel_type") @@ -425,6 +428,7 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag + self.hs.config.experimental.msc3758_exact_event_match, ) users = rules_by_user.keys() @@ -501,15 +505,15 @@ StateGroup = Union[object, int] def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, + result: Optional[Dict[str, SimpleJsonValue]] = None, *, msc3783_escape_event_match_key: bool = False, -) -> Dict[str, str]: +) -> Dict[str, SimpleJsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, flatten it into a single layer dictionary by combining the keys & sub-keys. - Any (non-dictionary), non-string value is dropped. + String, integer, boolean, and null values are kept. All others are dropped. Transforms: @@ -538,8 +542,8 @@ def _flatten_dict( # nested fields. key = key.replace("\\", "\\\\").replace(".", "\\.") - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() + if isinstance(value, (bool, str)) or type(value) is int or value is None: + result[".".join(prefix + [key])] = value elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below _flatten_dict( diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index f82d1cfc29..52e366c8ae 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -69,6 +69,8 @@ StateMap = Mapping[StateKey, T] MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. +# A "simple" (canonical) JSON value. +SimpleJsonValue = Optional[Union[str, int, bool]] # A JSON-serialisable dict. JsonDict = Dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 516b65cc3c..6603447341 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -57,7 +57,7 @@ class FlattenDictTestCase(unittest.TestCase): ) def test_non_string(self) -> None: - """Non-string items are dropped.""" + """Booleans, ints, and nulls should be kept while other items are dropped.""" input: Dict[str, Any] = { "woo": "woo", "foo": True, @@ -66,7 +66,9 @@ class FlattenDictTestCase(unittest.TestCase): "fuzz": [], "boo": {}, } - self.assertEqual({"woo": "woo"}, _flatten_dict(input)) + self.assertEqual( + {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input) + ) def test_event(self) -> None: """Events can also be flattened.""" @@ -86,9 +88,9 @@ class FlattenDictTestCase(unittest.TestCase): ) expected = { "content.msgtype": "m.text", - "content.body": "hello world!", + "content.body": "Hello world!", "content.format": "org.matrix.custom.html", - "content.formatted_body": "

hello world!

", + "content.formatted_body": "

Hello world!

", "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", @@ -166,6 +168,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): related_event_match_enabled=True, room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, + msc3758_exact_event_match=True, ) def test_display_name(self) -> None: @@ -410,6 +413,142 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should not match before a newline", ) + def test_exact_event_match_string(self) -> None: + """Check that exact_event_match conditions work as expected for strings.""" + + # Test against a string value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": "foobaz"}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": "FoobaZ"}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "test foobaz test"}, + "values must exactly match", + ) + value: Any + for value in (True, False, 1, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "foobaz"}), + "values should match on frozendicts", + ) + + def test_exact_event_match_boolean(self) -> None: + """Check that exact_event_match conditions work as expected for booleans.""" + + # Test against a True boolean value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": True, + } + self._assert_matches( + condition, + {"value": True}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": False}, + "incorrect values should not match", + ) + for value in ("foobaz", 1, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + # Test against a False boolean value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": False, + } + self._assert_matches( + condition, + {"value": False}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": True}, + "incorrect values should not match", + ) + # Choose false-y values to ensure there's no type coercion. + for value in ("", 0, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_match_null(self) -> None: + """Check that exact_event_match conditions work as expected for null.""" + + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": None, + } + self._assert_matches( + condition, + {"value": None}, + "exact value should match", + ) + for value in ("foobaz", True, False, 1, 1.1, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_match_integer(self) -> None: + """Check that exact_event_match conditions work as expected for integers.""" + + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": 1, + } + self._assert_matches( + condition, + {"value": 1}, + "exact value should match", + ) + value: Any + for value in (1.1, -1, 0): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect values should not match", + ) + for value in ("1", True, False, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) -- cgit 1.5.1 From d0c713cc85f094c323b2ba3f02d8ac411a7f0705 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 10 Feb 2023 23:29:00 +0000 Subject: Return read-only collections from `@cached` methods (#13755) It's important that collections returned from `@cached` methods are not modified, otherwise future retrievals from the cache will return the modified collection. This applies to the return values from `@cached` methods and the values inside the dictionaries returned by `@cachedList` methods. It's not necessary for the dictionaries returned by `@cachedList` methods themselves to be read-only. Signed-off-by: Sean Quah Co-authored-by: David Robertson --- changelog.d/13755.misc | 1 + synapse/app/phone_stats_home.py | 4 ++-- synapse/config/room_directory.py | 6 +++--- synapse/events/builder.py | 6 +++--- synapse/federation/federation_server.py | 3 ++- synapse/handlers/directory.py | 6 +++--- synapse/handlers/receipts.py | 4 ++-- synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 4 ++-- synapse/push/bulk_push_rule_evaluator.py | 4 ++-- synapse/state/__init__.py | 2 +- synapse/storage/controllers/state.py | 6 +++--- synapse/storage/databases/main/account_data.py | 2 +- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/devices.py | 17 +++++++++------ synapse/storage/databases/main/directory.py | 4 ++-- synapse/storage/databases/main/end_to_end_keys.py | 25 +++++++++++++--------- synapse/storage/databases/main/event_federation.py | 11 ++++++---- .../storage/databases/main/monthly_active_users.py | 4 ++-- synapse/storage/databases/main/receipts.py | 10 +++++---- synapse/storage/databases/main/registration.py | 4 ++-- synapse/storage/databases/main/relations.py | 7 ++++-- synapse/storage/databases/main/roommember.py | 19 ++++++++-------- synapse/storage/databases/main/signatures.py | 6 +++--- synapse/storage/databases/main/tags.py | 8 ++++--- synapse/storage/databases/main/user_directory.py | 4 ++-- tests/rest/admin/test_server_notice.py | 4 ++-- 27 files changed, 98 insertions(+), 77 deletions(-) create mode 100644 changelog.d/13755.misc (limited to 'synapse/config') diff --git a/changelog.d/13755.misc b/changelog.d/13755.misc new file mode 100644 index 0000000000..662ee00e99 --- /dev/null +++ b/changelog.d/13755.misc @@ -0,0 +1 @@ +Re-type hint some collections as read-only. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 53db1e85b3..897dd3edac 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,7 +15,7 @@ import logging import math import resource import sys -from typing import TYPE_CHECKING, List, Sized, Tuple +from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple from prometheus_client import Gauge @@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None: @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 - current_mau_count_by_service = {} + current_mau_count_by_service: Mapping[str, int] = {} reserved_users: Sized = () store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 3ed236217f..8666c22f01 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Collection from matrix_common.regex import glob_to_regex @@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config): return False def is_publishing_room_allowed( - self, user_id: str, room_id: str, aliases: List[str] + self, user_id: str, room_id: str, aliases: Collection[str] ) -> bool: """Checks if the given user is allowed to publish the room @@ -122,7 +122,7 @@ class _RoomDirectoryRule: except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: + def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 94dd1298e1..c82745275f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey @@ -103,7 +103,7 @@ class EventBuilder: async def build( self, - prev_event_ids: List[str], + prev_event_ids: Collection[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: @@ -136,7 +136,7 @@ class EventBuilder: format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8d36172484..6addc0bb65 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -23,6 +23,7 @@ from typing import ( Collection, Dict, List, + Mapping, Optional, Tuple, Union, @@ -1512,7 +1513,7 @@ class FederationHandlerRegistry: def _get_event_ids_for_partial_state_join( join_event: EventBase, prev_state_ids: StateMap[str], - summary: Dict[str, MemberSummary], + summary: Mapping[str, MemberSummary], ) -> Collection[str]: """Calculate state to be returned in a partial_state send_join diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index d31b0fbb17..a5798e9483 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -14,7 +14,7 @@ import logging import string -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence from typing_extensions import Literal @@ -486,7 +486,7 @@ class DirectoryHandler: ) if canonical_alias: # Ensure we do not mutate room_aliases. - room_aliases = room_aliases + [canonical_alias] + room_aliases = list(room_aliases) + [canonical_alias] if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases @@ -529,7 +529,7 @@ class DirectoryHandler: async def get_aliases_for_room( self, requester: Requester, room_id: str - ) -> List[str]: + ) -> Sequence[str]: """ Get a list of the aliases that currently point to this room on this server """ diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 04c61ae3dd..2bacdebfb5 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService @@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): @staticmethod def filter_out_private_receipts( - rooms: List[JsonDict], user_id: str + rooms: Sequence[JsonDict], user_id: str ) -> List[JsonDict]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0e759b8a5d..060bbcb181 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1928,6 +1928,6 @@ class RoomShutdownHandler: return { "kicked_users": kicked_users, "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, + "local_aliases": list(aliases_for_room), "new_room_id": new_room_id, } diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 399685e5b7..4bae46158a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1519,7 +1519,7 @@ class SyncHandler: one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = ( + unused_fallback_key_types = list( await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) @@ -2301,7 +2301,7 @@ class SyncHandler: sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[Dict[str, Dict[str, Any]]], + tags: Optional[Mapping[str, Mapping[str, Any]]], account_data: Mapping[str, JsonDict], always_include: bool = False, ) -> None: diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8568aca528..f6a5bffb0f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,6 +22,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -149,7 +150,7 @@ class BulkPushRuleEvaluator: # little, we can skip fetching a huge number of push rules in large rooms. # This helps make joins and leaves faster. if event.type == EventTypes.Member: - local_users = [] + local_users: Sequence[str] = [] # We never notify a user about their own actions. This is enforced in # `_action_for_event_by_user` in the loop over `rules_by_user`, but we # do the same check here to avoid unnecessary DB queries. @@ -184,7 +185,6 @@ class BulkPushRuleEvaluator: if event.type == EventTypes.Member and event.membership == Membership.INVITE: invited = event.state_key if invited and self.hs.is_mine_id(invited) and invited not in local_users: - local_users = list(local_users) local_users.append(invited) if not local_users: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index e877e6f1a1..4dc25df67e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -226,7 +226,7 @@ class StateHandler: return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: List[str] + self, room_id: str, latest_event_ids: Collection[str] ) -> Set[str]: """ Get the users IDs who are currently in a room. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 52efd4a171..9d7a8a792f 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,6 +14,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -23,7 +24,6 @@ from typing import ( List, Mapping, Optional, - Set, Tuple, ) @@ -527,7 +527,7 @@ class StateStorageController: ) return state_map.get(key) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms @@ -584,7 +584,7 @@ class StateStorageController: async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """ Get the current users in the room with their profiles. If the room is currently partial-stated, this will block until the room has diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 2d6f02c14f..95567826f2 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get all the client account_data for a user for a room. Args: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5fb152c4ff..484db175d0 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): room_id: str, app_service: "ApplicationService", cache_context: _CacheContext, - ) -> List[str]: + ) -> Sequence[str]: """ Get all users in a room that the appservice controls. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 85c1778a81..1ca66d57d4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -21,6 +21,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + async def count_devices_by_users( + self, user_ids: Optional[Collection[str]] = None + ) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ def count_devices_by_users_txn( - txn: LoggingTransaction, user_ids: List[str] + txn: LoggingTransaction, user_ids: Collection[str] ) -> int: sql = """ SELECT count(*) @@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cancellable async def get_user_devices_from_cache( self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] - ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: @@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. - results: Dict[str, Dict[str, JsonDict]] = {} + results: Dict[str, Mapping[str, JsonDict]] = {} for user_id in user_ids: if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) # Then fetch all device-specific requests, but skip users we've already # fetched all devices for. + device_specific_results: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_id in user_and_device_ids: if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device + device_specific_results.setdefault(user_id, {})[device_id] = device + results.update(device_specific_results) set_tag("in_cache", str(results)) set_tag("not_in_cache", str(user_ids_not_in_cache)) @@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 5903fdaf00..44aa181174 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple import attr @@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): ) @cached(max_entries=5000) - async def get_aliases_for_room(self, room_id: str) -> List[str]: + async def get_aliases_for_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c4ac6c33ba..752dc16e17 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -20,7 +20,9 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, Union, cast, @@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( self, user_id: str, device_id: str - ) -> List[str]: + ) -> Sequence[str]: """Returns the fallback key types that have an unused key. Args: @@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) def _get_bare_e2e_cross_signing_keys_bulk_txn( self, @@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures", - self._get_e2e_cross_signing_signatures_txn, - result, - from_user_id, + result = cast( + Dict[str, Optional[Mapping[str, JsonDict]]], + await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ), ) return result diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index bbee02ab18..ca780cca36 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -22,6 +22,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, cast, @@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: + async def get_max_depth_of( + self, event_ids: Collection[str] + ) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, @@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @cancellable async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @cached(max_entries=5000, num_args=2) async def _get_forward_extremeties_for_room( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index db9a24db5e..4b1061e6d7 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - async def get_monthly_active_count_by_service(self) -> Dict[str, int]: + async def get_monthly_active_count_by_service(self) -> Mapping[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 29972d5204..dddf49c2d5 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -21,7 +21,9 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, cast, ) @@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> Sequence[JsonDict]: """Get receipts for a single room for sending to clients. Args: @@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[JsonDict]: + ) -> Sequence[JsonDict]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: @@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, List[JsonDict]]: + ) -> Dict[str, Sequence[JsonDict]]: if not room_ids: return {} @@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 31f0f2bd3d..9a55e17624 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import attr @@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0018d6f7ab..fa3266c081 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -22,6 +22,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: + ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore): return result is not None @cached() - async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + async def get_aggregation_groups_for_event( + self, event_id: str + ) -> Sequence[JsonDict]: raise NotImplementedError() @cachedList( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ea6a5e2f34..694a5b802c 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -24,6 +24,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self._known_servers_count @cached(max_entries=100000, iterable=True) - async def get_users_in_room(self, room_id: str) -> List[str]: + async def get_users_in_room(self, room_id: str) -> Sequence[str]: """Returns a list of users in the room. Will return inaccurate results for rooms with partial state, since the state for @@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() - def get_user_in_room_with_profile( - self, room_id: str, user_id: str - ) -> Dict[str, ProfileInfo]: + def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo: raise NotImplementedError() @cachedList( @@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for all users in a given room. The profile information comes directly from this room's `m.room.member` @@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000) - async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: + async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: @@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached() async def get_invited_rooms_for_local_user( self, user_id: str - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(iterable=True) - async def get_local_users_in_room(self, room_id: str) -> List[str]: + async def get_local_users_in_room(self, room_id: str) -> Sequence[str]: """ Retrieves a list of the current roommembers who are local to the server. """ @@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) - user_who_share_room = set() + user_who_share_room: Set[str] = set() for room_id in room_ids: user_ids = await self.get_users_in_room(room_id) user_who_share_room.update(user_ids) @@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state.""" # First we check if we already have `get_users_in_room` in the cache, as diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index 05da15074a..5dcb1fc0b5 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, List, Tuple +from typing import Collection, Dict, List, Mapping, Tuple from unpaddedbase64 import encode_base64 @@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList class SignatureWorkerStore(EventsWorkerStore): @cached() - def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: + def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]: # This is a dummy function to allow get_event_reference_hashes # to use its cache raise NotImplementedError() @@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore): ) async def get_event_reference_hashes( self, event_ids: Collection[str] - ) -> Dict[str, Dict[str, bytes]]: + ) -> Mapping[str, Mapping[str, bytes]]: """Get all hashes for given events. Args: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index d5500cdd47..c149a9eacb 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.tcp.streams import AccountDataStream @@ -32,7 +32,9 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: + async def get_tags_for_user( + self, user_id: str + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for a user. @@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for the rooms where the tags have changed since the given version diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 14ef5b040d..f6a6fd4079 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,9 +16,9 @@ import logging import re from typing import ( TYPE_CHECKING, - Dict, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: + async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index a2f347f666..f71ff46d87 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Check invite and room membership status of a user. Args -- cgit 1.5.1 From 119e0795a58548fb38fab299e7c362fcbb388d68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Feb 2023 14:02:19 -0500 Subject: Implement MSC3966: Add a push rule condition to search for a value in an array. (#15045) The `exact_event_property_contains` condition can be used to search for a value inside of an array. --- changelog.d/15045.feature | 1 + rust/benches/evaluator.rs | 32 +++++++++------- rust/src/push/evaluator.rs | 65 +++++++++++++++++++++++++------- rust/src/push/mod.rs | 33 +++++++++++++++- stubs/synapse/synapse_rust/push.pyi | 7 ++-- synapse/config/experimental.py | 5 +++ synapse/push/bulk_push_rule_evaluator.py | 21 +++++++---- synapse/types/__init__.py | 1 + tests/push/test_push_rule_evaluator.py | 53 ++++++++++++++++++++++++-- 9 files changed, 176 insertions(+), 42 deletions(-) create mode 100644 changelog.d/15045.feature (limited to 'synapse/config') diff --git a/changelog.d/15045.feature b/changelog.d/15045.feature new file mode 100644 index 0000000000..87766befda --- /dev/null +++ b/changelog.d/15045.feature @@ -0,0 +1 @@ +Experimental support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): the `exact_event_property_contains` push rule condition. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 229553ebf8..8213dfd9ea 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -15,8 +15,8 @@ #![feature(test)] use std::collections::BTreeSet; use synapse::push::{ - evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules, - SimpleJsonValue, + evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue, + PushRules, SimpleJsonValue, }; use test::Bencher; @@ -27,15 +27,15 @@ fn bench_match_exact(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -54,6 +54,7 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -76,15 +77,15 @@ fn bench_match_word(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -103,6 +104,7 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -125,15 +127,15 @@ fn bench_match_word_miss(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -152,6 +154,7 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -174,15 +177,15 @@ fn bench_eval_message(b: &mut Bencher) { let flattened_keys = [ ( "type".to_string(), - SimpleJsonValue::Str("m.text".to_string()), + JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())), ), ( "room_id".to_string(), - SimpleJsonValue::Str("!room:server".to_string()), + JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())), ), ( "content.body".to_string(), - SimpleJsonValue::Str("test message".to_string()), + JsonValue::Value(SimpleJsonValue::Str("test message".to_string())), ), ] .into_iter() @@ -201,6 +204,7 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index dd6b4343ec..2eaa06ad76 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -14,6 +14,7 @@ use std::collections::{BTreeMap, BTreeSet}; +use crate::push::JsonValue; use anyhow::{Context, Error}; use lazy_static::lazy_static; use log::warn; @@ -63,7 +64,7 @@ impl RoomVersionFeatures { pub struct PushRuleEvaluator { /// A mapping of "flattened" keys to simple JSON values in the event, e.g. /// includes things like "type" and "content.msgtype". - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, /// The "content.body", if any. body: String, @@ -87,7 +88,7 @@ pub struct PushRuleEvaluator { /// The related events, indexed by relation type. Flattened in the same manner as /// `flattened_keys`. - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, /// If msc3664, push rules for related events, is enabled. related_event_match_enabled: bool, @@ -101,6 +102,9 @@ pub struct PushRuleEvaluator { /// If MSC3758 (exact_event_match push rule condition) is enabled. msc3758_exact_event_match: bool, + + /// If MSC3966 (exact_event_property_contains push rule condition) is enabled. + msc3966_exact_event_property_contains: bool, } #[pymethods] @@ -109,21 +113,22 @@ impl PushRuleEvaluator { #[allow(clippy::too_many_arguments)] #[new] pub fn py_new( - flattened_keys: BTreeMap, + flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - related_events_flattened: BTreeMap>, + related_events_flattened: BTreeMap>, related_event_match_enabled: bool, room_version_feature_flags: Vec, msc3931_enabled: bool, msc3758_exact_event_match: bool, + msc3966_exact_event_property_contains: bool, ) -> Result { let body = match flattened_keys.get("content.body") { - Some(SimpleJsonValue::Str(s)) => s.clone(), + Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(), _ => String::new(), }; @@ -141,6 +146,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc3758_exact_event_match, + msc3966_exact_event_property_contains, }) } @@ -263,6 +269,9 @@ impl PushRuleEvaluator { KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } + KnownCondition::ExactEventPropertyContains(exact_event_match) => { + self.match_exact_event_property_contains(exact_event_match)? + } KnownCondition::IsUserMention => { if let Some(uid) = user_id { self.user_mentions.contains(uid) @@ -345,7 +354,7 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(SimpleJsonValue::Str(haystack)) = + let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = self.flattened_keys.get(&*event_match.key) { haystack @@ -377,7 +386,9 @@ impl PushRuleEvaluator { let value = &exact_event_match.value; - let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) { + let haystack = if let Some(JsonValue::Value(haystack)) = + self.flattened_keys.get(&*exact_event_match.key) + { haystack } else { return Ok(false); @@ -441,11 +452,12 @@ impl PushRuleEvaluator { return Ok(false); }; - let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) { - haystack - } else { - return Ok(false); - }; + let haystack = + if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) { + haystack + } else { + return Ok(false); + }; // For the content.body we match against "words", but for everything // else we match against the entire value. @@ -459,6 +471,29 @@ impl PushRuleEvaluator { compiled_pattern.is_match(haystack) } + /// Evaluates a `exact_event_property_contains` condition. (MSC3758) + fn match_exact_event_property_contains( + &self, + exact_event_match: &ExactEventMatchCondition, + ) -> Result { + // First check if the feature is enabled. + if !self.msc3966_exact_event_property_contains { + return Ok(false); + } + + let value = &exact_event_match.value; + + let haystack = if let Some(JsonValue::Array(haystack)) = + self.flattened_keys.get(&*exact_event_match.key) + { + haystack + } else { + return Ok(false); + }; + + Ok(haystack.contains(&**value)) + } + /// Match the member count against an 'is' condition /// The `is` condition can be things like '>2', '==3' or even just '4'. fn match_member_count(&self, is: &str) -> Result { @@ -488,7 +523,7 @@ fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert( "content.body".to_string(), - SimpleJsonValue::Str("foo bar bob hello".to_string()), + JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())), ); let evaluator = PushRuleEvaluator::py_new( flattened_keys, @@ -503,6 +538,7 @@ fn push_rule_evaluator() { vec![], true, true, + true, ) .unwrap(); @@ -519,7 +555,7 @@ fn test_requires_room_version_supports_condition() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert( "content.body".to_string(), - SimpleJsonValue::Str("foo bar bob hello".to_string()), + JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())), ); let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( @@ -535,6 +571,7 @@ fn test_requires_room_version_supports_condition() { flags, true, true, + true, ) .unwrap(); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 79e519fe11..253b5f367c 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -58,7 +58,7 @@ use anyhow::{Context, Error}; use log::warn; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyLong, PyString}; +use pyo3::types::{PyBool, PyList, PyLong, PyString}; use pythonize::{depythonize, pythonize}; use serde::de::Error as _; use serde::{Deserialize, Serialize}; @@ -280,6 +280,35 @@ impl<'source> FromPyObject<'source> for SimpleJsonValue { } } +/// A JSON values (list, string, int, boolean, or null). +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(untagged)] +pub enum JsonValue { + Array(Vec), + Value(SimpleJsonValue), +} + +impl<'source> FromPyObject<'source> for JsonValue { + fn extract(ob: &'source PyAny) -> PyResult { + if let Ok(l) = ::try_from(ob) { + match l.iter().map(SimpleJsonValue::extract).collect() { + Ok(a) => Ok(JsonValue::Array(a)), + Err(e) => Err(PyTypeError::new_err(format!( + "Can't convert to JsonValue::Array: {}", + e + ))), + } + } else if let Ok(v) = SimpleJsonValue::extract(ob) { + Ok(JsonValue::Value(v)) + } else { + Err(PyTypeError::new_err(format!( + "Can't convert from {} to JsonValue", + ob.get_type().name()? + ))) + } + } +} + /// A condition used in push rules to match against an event. /// /// We need this split as `serde` doesn't give us the ability to have a @@ -303,6 +332,8 @@ pub enum KnownCondition { ExactEventMatch(ExactEventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3966.exact_event_property_contains")] + ExactEventPropertyContains(ExactEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] IsUserMention, #[serde(rename = "org.matrix.msc3952.is_room_mention")] diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 328f681a29..7b33c30cc9 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -14,7 +14,7 @@ from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union -from synapse.types import JsonDict, SimpleJsonValue +from synapse.types import JsonDict, JsonValue class PushRule: @property @@ -56,18 +56,19 @@ def get_base_rule_ids() -> Collection[str]: ... class PushRuleEvaluator: def __init__( self, - flattened_keys: Mapping[str, SimpleJsonValue], + flattened_keys: Mapping[str, JsonValue], has_mentions: bool, user_mentions: Set[str], room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]], + related_events_flattened: Mapping[str, Mapping[str, JsonValue]], related_event_match_enabled: bool, room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, msc3758_exact_event_match: bool, + msc3966_exact_event_property_contains: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 6ac2f0c10d..1d294f8798 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -188,3 +188,8 @@ class ExperimentalConfig(Config): self.msc3958_supress_edit_notifs = experimental.get( "msc3958_supress_edit_notifs", False ) + + # MSC3966: exact_event_property_contains push rule condition. + self.msc3966_exact_event_property_contains = experimental.get( + "msc3966_exact_event_property_contains", False + ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f6a5bffb0f..2e917c90c4 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -44,7 +44,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator -from synapse.types import SimpleJsonValue +from synapse.types import JsonValue from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func @@ -259,13 +259,13 @@ class BulkPushRuleEvaluator: async def _related_events( self, event: EventBase - ) -> Dict[str, Dict[str, SimpleJsonValue]]: + ) -> Dict[str, Dict[str, JsonValue]]: """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: Mapping of relation type to flattened events. """ - related_events: Dict[str, Dict[str, SimpleJsonValue]] = {} + related_events: Dict[str, Dict[str, JsonValue]] = {} if self._related_event_match_enabled: related_event_id = event.content.get("m.relates_to", {}).get("event_id") relation_type = event.content.get("m.relates_to", {}).get("rel_type") @@ -429,6 +429,7 @@ class BulkPushRuleEvaluator: event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc3758_exact_event_match, + self.hs.config.experimental.msc3966_exact_event_property_contains, ) users = rules_by_user.keys() @@ -502,18 +503,22 @@ RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] +def _is_simple_value(value: Any) -> bool: + return isinstance(value, (bool, str)) or type(value) is int or value is None + + def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, - result: Optional[Dict[str, SimpleJsonValue]] = None, + result: Optional[Dict[str, JsonValue]] = None, *, msc3783_escape_event_match_key: bool = False, -) -> Dict[str, SimpleJsonValue]: +) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, flatten it into a single layer dictionary by combining the keys & sub-keys. - String, integer, boolean, and null values are kept. All others are dropped. + String, integer, boolean, null or lists of those values are kept. All others are dropped. Transforms: @@ -542,8 +547,10 @@ def _flatten_dict( # nested fields. key = key.replace("\\", "\\\\").replace(".", "\\.") - if isinstance(value, (bool, str)) or type(value) is int or value is None: + if _is_simple_value(value): result[".".join(prefix + [key])] = value + elif isinstance(value, (list, tuple)): + result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)] elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below _flatten_dict( diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 52e366c8ae..33363867c4 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -71,6 +71,7 @@ MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. # A "simple" (canonical) JSON value. SimpleJsonValue = Optional[Union[str, int, bool]] +JsonValue = Union[List[SimpleJsonValue], Tuple[SimpleJsonValue, ...], SimpleJsonValue] # A JSON-serialisable dict. JsonDict = Dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 6603447341..0554d247bc 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict, JsonMapping, UserID from synapse.util import Clock +from synapse.util.frozenutils import freeze from tests import unittest from tests.test_utils.event_injection import create_event, inject_member_event @@ -57,17 +58,24 @@ class FlattenDictTestCase(unittest.TestCase): ) def test_non_string(self) -> None: - """Booleans, ints, and nulls should be kept while other items are dropped.""" + """String, booleans, ints, nulls and list of those should be kept while other items are dropped.""" input: Dict[str, Any] = { "woo": "woo", "foo": True, "bar": 1, "baz": None, - "fuzz": [], + "fuzz": ["woo", True, 1, None, [], {}], "boo": {}, } self.assertEqual( - {"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input) + { + "woo": "woo", + "foo": True, + "bar": 1, + "baz": None, + "fuzz": ["woo", True, 1, None], + }, + _flatten_dict(input), ) def test_event(self) -> None: @@ -117,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -128,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -169,6 +179,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, msc3758_exact_event_match=True, + msc3966_exact_event_property_contains=True, ) def test_display_name(self) -> None: @@ -549,6 +560,42 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "incorrect types should not match", ) + def test_exact_event_property_contains(self) -> None: + """Check that exact_event_property_contains conditions work as expected.""" + + condition = { + "kind": "org.matrix.msc3966.exact_event_property_contains", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": ["foobaz"]}, + "exact value should match", + ) + self._assert_matches( + condition, + {"value": ["foobaz", "bugz"]}, + "extra values should match", + ) + self._assert_not_matches( + condition, + {"value": ["FoobaZ"]}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "foobaz"}, + "does not search in a string", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + freeze({"value": ["foobaz"]}), + "values should match on frozendicts", + ) + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) -- cgit 1.5.1 From 27a3a72a50cb24f25e48fad1e6e79aba2cd1bea2 Mon Sep 17 00:00:00 2001 From: 999lakhisidhu <42063995+999lakhisidhu@users.noreply.github.com> Date: Wed, 15 Feb 2023 16:39:31 +0400 Subject: Support for selecting the Redis logical database. (#15034) Note that this is only used for key-value store (cached values) and not for the pub/sub replication used by Synapse. --- changelog.d/15034.feature | 1 + contrib/docker_compose_workers/README.md | 1 + docs/usage/configuration/config_documentation.md | 4 ++++ synapse/config/redis.py | 1 + synapse/server.py | 1 + 5 files changed, 8 insertions(+) create mode 100644 changelog.d/15034.feature (limited to 'synapse/config') diff --git a/changelog.d/15034.feature b/changelog.d/15034.feature new file mode 100644 index 0000000000..34f320da92 --- /dev/null +++ b/changelog.d/15034.feature @@ -0,0 +1 @@ +Allow Synapse to use a specific Redis [logical database](https://redis.io/commands/select/) in worker-mode deployments. diff --git a/contrib/docker_compose_workers/README.md b/contrib/docker_compose_workers/README.md index bdd3dd32e0..d3cdfe5614 100644 --- a/contrib/docker_compose_workers/README.md +++ b/contrib/docker_compose_workers/README.md @@ -68,6 +68,7 @@ redis: enabled: true host: redis port: 6379 + # dbid: # password: ``` diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 2883f76a26..75483bfb12 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3927,6 +3927,9 @@ This setting has the following sub-options: * `host` and `port`: Optional host and port to use to connect to redis. Defaults to localhost and 6379 * `password`: Optional password if configured on the Redis instance. +* `dbid`: Optional redis dbid if needs to connect to specific redis logical db. + + _Added in Synapse 1.78.0._ Example configuration: ```yaml @@ -3935,6 +3938,7 @@ redis: host: localhost port: 6379 password: + dbid: ``` --- ## Individual worker configuration diff --git a/synapse/config/redis.py b/synapse/config/redis.py index b42dd2e93a..e6a75be434 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -33,4 +33,5 @@ class RedisConfig(Config): self.redis_host = redis_config.get("host", "localhost") self.redis_port = redis_config.get("port", 6379) + self.redis_dbid = redis_config.get("dbid", None) self.redis_password = redis_config.get("password") diff --git a/synapse/server.py b/synapse/server.py index efc6b5f895..e5a3475247 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -827,6 +827,7 @@ class HomeServer(metaclass=abc.ABCMeta): hs=self, host=self.config.redis.redis_host, port=self.config.redis.redis_port, + dbid=self.config.redis.redis_dbid, password=self.config.redis.redis_password, reconnect=True, ) -- cgit 1.5.1 From 979f237b282cbdaab8d74cc4c7473117093d63d9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 16 Feb 2023 09:51:22 -0500 Subject: Update intentional mentions (MSC3952) to depend on `exact_event_match` (MSC3758). (#15037) This replaces the specific `is_room_mention` push rule condition used in MSC3952 with the generic `exact_event_match` push rule condition from MSC3758. No functionality changes due to this. --- changelog.d/15037.misc | 1 + rust/benches/evaluator.rs | 4 ---- rust/src/push/base_rules.rs | 7 +++++-- rust/src/push/evaluator.rs | 7 ------- rust/src/push/mod.rs | 13 ------------- stubs/synapse/synapse_rust/push.pyi | 1 - synapse/config/experimental.py | 7 ++++--- synapse/push/bulk_push_rule_evaluator.py | 4 ---- tests/push/test_bulk_push_rule_evaluator.py | 18 ++++++++++++++++-- tests/push/test_push_rule_evaluator.py | 23 ----------------------- 10 files changed, 26 insertions(+), 59 deletions(-) create mode 100644 changelog.d/15037.misc (limited to 'synapse/config') diff --git a/changelog.d/15037.misc b/changelog.d/15037.misc new file mode 100644 index 0000000000..fabfe77d35 --- /dev/null +++ b/changelog.d/15037.misc @@ -0,0 +1 @@ +Update [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952) support based on changes to the MSC. diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 8213dfd9ea..efd19a2165 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -45,7 +45,6 @@ fn bench_match_exact(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -95,7 +94,6 @@ fn bench_match_word(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -145,7 +143,6 @@ fn bench_match_word_miss(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), @@ -195,7 +192,6 @@ fn bench_eval_message(b: &mut Bencher) { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), Default::default(), diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index dcbca340fe..4a62b9696f 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -21,13 +21,13 @@ use lazy_static::lazy_static; use serde_json::Value; use super::KnownCondition; -use crate::push::Action; use crate::push::Condition; use crate::push::EventMatchCondition; use crate::push::PushRule; use crate::push::RelatedEventMatchCondition; use crate::push::SetTweak; use crate::push::TweakValue; +use crate::push::{Action, ExactEventMatchCondition, SimpleJsonValue}; const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak { set_tweak: Cow::Borrowed("highlight"), @@ -168,7 +168,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mention"), priority_class: 5, conditions: Cow::Borrowed(&[ - Condition::Known(KnownCondition::IsRoomMention), + Condition::Known(KnownCondition::ExactEventMatch(ExactEventMatchCondition { + key: Cow::Borrowed("content.org.matrix.msc3952.mentions.room"), + value: Cow::Borrowed(&SimpleJsonValue::Bool(true)), + })), Condition::Known(KnownCondition::SenderNotificationPermission { key: Cow::Borrowed("room"), }), diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index 2eaa06ad76..55551ecb56 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -73,8 +73,6 @@ pub struct PushRuleEvaluator { has_mentions: bool, /// The user mentions that were part of the message. user_mentions: BTreeSet, - /// True if the message is a room message. - room_mention: bool, /// The number of users in the room. room_member_count: u64, @@ -116,7 +114,6 @@ impl PushRuleEvaluator { flattened_keys: BTreeMap, has_mentions: bool, user_mentions: BTreeSet, - room_mention: bool, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, @@ -137,7 +134,6 @@ impl PushRuleEvaluator { body, has_mentions, user_mentions, - room_mention, room_member_count, notification_power_levels, sender_power_level, @@ -279,7 +275,6 @@ impl PushRuleEvaluator { false } } - KnownCondition::IsRoomMention => self.room_mention, KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -529,7 +524,6 @@ fn push_rule_evaluator() { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), BTreeMap::new(), @@ -562,7 +556,6 @@ fn test_requires_room_version_supports_condition() { flattened_keys, false, BTreeSet::new(), - false, 10, Some(0), BTreeMap::new(), diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 253b5f367c..fdd2b2c143 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -336,8 +336,6 @@ pub enum KnownCondition { ExactEventPropertyContains(ExactEventMatchCondition), #[serde(rename = "org.matrix.msc3952.is_user_mention")] IsUserMention, - #[serde(rename = "org.matrix.msc3952.is_room_mention")] - IsRoomMention, ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -667,17 +665,6 @@ fn test_deserialize_unstable_msc3952_user_condition() { )); } -#[test] -fn test_deserialize_unstable_msc3952_room_condition() { - let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; - - let condition: Condition = serde_json::from_str(json).unwrap(); - assert!(matches!( - condition, - Condition::Known(KnownCondition::IsRoomMention) - )); -} - #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 7b33c30cc9..a8f0ed2435 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -59,7 +59,6 @@ class PushRuleEvaluator: flattened_keys: Mapping[str, JsonValue], has_mentions: bool, user_mentions: Set[str], - room_mention: bool, room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1d294f8798..54c91953e1 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -179,9 +179,10 @@ class ExperimentalConfig(Config): "msc3783_escape_event_match_key", False ) - # MSC3952: Intentional mentions - self.msc3952_intentional_mentions = experimental.get( - "msc3952_intentional_mentions", False + # MSC3952: Intentional mentions, this depends on MSC3758. + self.msc3952_intentional_mentions = ( + experimental.get("msc3952_intentional_mentions", False) + and self.msc3758_exact_event_match ) # MSC3959: Do not generate notifications for edits. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 2e917c90c4..5fc38431ba 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -400,7 +400,6 @@ class BulkPushRuleEvaluator: mentions = event.content.get(EventContentFields.MSC3952_MENTIONS) has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict) user_mentions: Set[str] = set() - room_mention = False if has_mentions: # mypy seems to have lost the type even though it must be a dict here. assert isinstance(mentions, dict) @@ -410,8 +409,6 @@ class BulkPushRuleEvaluator: user_mentions = set( filter(lambda item: isinstance(item, str), user_mentions_raw) ) - # Room mention is only true if the value is exactly true. - room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( _flatten_dict( @@ -420,7 +417,6 @@ class BulkPushRuleEvaluator: ), has_mentions, user_mentions, - room_mention, room_member_count, sender_power_level, notification_levels, diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 7567756135..199e3d7b70 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -227,7 +227,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) return len(result) > 0 - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3758_exact_event_match": True, + "msc3952_intentional_mentions": True, + } + } + ) def test_user_mentions(self) -> None: """Test the behavior of an event which includes invalid user mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -323,7 +330,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) ) - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3758_exact_event_match": True, + "msc3952_intentional_mentions": True, + } + } + ) def test_room_mentions(self) -> None: """Test the behavior of an event which includes invalid room mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 0554d247bc..d320a12f96 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -149,7 +149,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): *, has_mentions: bool = False, user_mentions: Optional[Set[str]] = None, - room_mention: bool = False, related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( @@ -170,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): _flatten_dict(event), has_mentions, user_mentions or set(), - room_mention, room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), @@ -232,27 +230,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions # since the BulkPushRuleEvaluator is what handles data sanitisation. - def test_room_mentions(self) -> None: - """Check for room mentions.""" - condition = {"kind": "org.matrix.msc3952.is_room_mention"} - - # No room mention shouldn't match. - evaluator = self._get_evaluator({}, has_mentions=True) - self.assertFalse(evaluator.matches(condition, None, None)) - - # Room mention should match. - evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True) - self.assertTrue(evaluator.matches(condition, None, None)) - - # A room mention and user mention is valid. - evaluator = self._get_evaluator( - {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True - ) - self.assertTrue(evaluator.matches(condition, None, None)) - - # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions - # since the BulkPushRuleEvaluator is what handles data sanitisation. - def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: -- cgit 1.5.1 From 490a3675bd7225b5695e505fea225d7c30127551 Mon Sep 17 00:00:00 2001 From: realtyem Date: Mon, 20 Feb 2023 06:23:00 -0600 Subject: Allow health listener resource to load (#15096) * Allow health listener resource to load. * changelog * Update changelog.d/15096.bugfix --- changelog.d/15096.bugfix | 1 + synapse/config/server.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/15096.bugfix (limited to 'synapse/config') diff --git a/changelog.d/15096.bugfix b/changelog.d/15096.bugfix new file mode 100644 index 0000000000..09b4d861f8 --- /dev/null +++ b/changelog.d/15096.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.76 where workers would fail to start if the `health` listener was configured. diff --git a/synapse/config/server.py b/synapse/config/server.py index ecdaa2d9dd..d4ef9930b0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -177,6 +177,7 @@ KNOWN_RESOURCES = { "client", "consent", "federation", + "health", "keys", "media", "metrics", -- cgit 1.5.1 From 4ed08ff72ef8f1abf85ab22de1e51b570f67b27e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Feb 2023 14:37:18 -0500 Subject: Tighten the default rate limit of creating new devices. (#15135) --- changelog.d/15135.misc | 1 + docs/usage/configuration/config_documentation.md | 6 +++--- synapse/config/ratelimiting.py | 13 +++++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 changelog.d/15135.misc (limited to 'synapse/config') diff --git a/changelog.d/15135.misc b/changelog.d/15135.misc new file mode 100644 index 0000000000..25c4dbffe1 --- /dev/null +++ b/changelog.d/15135.misc @@ -0,0 +1 @@ +Tighten the login ratelimit defaults. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 58c6955689..ab1f9f4963 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1518,11 +1518,11 @@ rc_registration_token_validity: This option specifies several limits for login: * `address` ratelimits login requests based on the client's IP - address. Defaults to `per_second: 0.17`, `burst_count: 3`. + address. Defaults to `per_second: 0.003`, `burst_count: 5`. * `account` ratelimits login requests based on the account the - client is attempting to log into. Defaults to `per_second: 0.17`, - `burst_count: 3`. + client is attempting to log into. Defaults to `per_second: 0.03`, + `burst_count: 5`. * `failed_attempts` ratelimits login requests based on the account the client is attempting to log into, based on the amount of failed login diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 5c13fe428a..b733fac617 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -87,9 +87,18 @@ class RatelimitConfig(Config): defaults={"per_second": 0.1, "burst_count": 5}, ) + # It is reasonable to login with a bunch of devices at once (i.e. when + # setting up an account), but it is *not* valid to continually be + # logging into new devices. rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings(rc_login_config.get("address", {})) - self.rc_login_account = RatelimitSettings(rc_login_config.get("account", {})) + self.rc_login_address = RatelimitSettings( + rc_login_config.get("address", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + self.rc_login_account = RatelimitSettings( + rc_login_config.get("account", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) self.rc_login_failed_attempts = RatelimitSettings( rc_login_config.get("failed_attempts", {}) ) -- cgit 1.5.1 From 9bb2eac71962970d02842bca441f4bcdbbf93a11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 15:29:09 -0500 Subject: Bump black from 22.12.0 to 23.1.0 (#15103) --- changelog.d/15103.misc | 1 + poetry.lock | 42 ++++++++++++++-------- stubs/sortedcontainers/sortedlist.pyi | 1 - synapse/_scripts/register_new_matrix_user.py | 2 -- synapse/_scripts/synapse_port_db.py | 1 - synapse/_scripts/synctl.py | 1 - synapse/app/_base.py | 2 +- synapse/app/complement_fork_starter.py | 2 +- synapse/app/generic_worker.py | 1 - synapse/app/homeserver.py | 1 - synapse/config/consent.py | 1 - synapse/config/database.py | 1 - synapse/config/homeserver.py | 1 - synapse/config/ratelimiting.py | 1 - synapse/config/repository.py | 1 - synapse/config/server.py | 1 - synapse/config/tls.py | 1 - synapse/crypto/keyring.py | 2 +- synapse/events/third_party_rules.py | 2 -- synapse/federation/send_queue.py | 4 +-- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 2 -- synapse/handlers/directory.py | 8 +++-- synapse/handlers/e2e_room_keys.py | 1 - synapse/handlers/event_auth.py | 1 - synapse/handlers/initial_sync.py | 1 - synapse/handlers/presence.py | 2 -- synapse/handlers/room.py | 8 +++-- synapse/handlers/room_batch.py | 2 +- synapse/handlers/sync.py | 1 - synapse/logging/opentracing.py | 1 + synapse/metrics/__init__.py | 1 - synapse/metrics/_gc.py | 1 - synapse/push/bulk_push_rule_evaluator.py | 1 - synapse/replication/http/account_data.py | 1 - synapse/replication/http/devices.py | 1 - synapse/replication/tcp/redis.py | 1 - synapse/replication/tcp/streams/events.py | 1 - synapse/rest/admin/rooms.py | 4 --- synapse/rest/admin/users.py | 8 +++-- synapse/rest/client/auth.py | 1 - synapse/rest/client/filter.py | 1 - synapse/rest/client/register.py | 18 ++++++---- synapse/rest/media/v1/_base.py | 1 - synapse/rest/media/v1/thumbnailer.py | 1 - synapse/storage/databases/main/deviceinbox.py | 5 ++- synapse/storage/databases/main/devices.py | 4 +-- synapse/storage/databases/main/e2e_room_keys.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 8 ++--- synapse/storage/databases/main/event_federation.py | 1 - synapse/storage/databases/main/events.py | 1 - .../storage/databases/main/events_bg_updates.py | 4 +-- synapse/storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/media_repository.py | 1 - synapse/storage/databases/main/pusher.py | 3 -- synapse/storage/databases/main/receipts.py | 1 - synapse/storage/databases/main/room.py | 1 - synapse/storage/databases/main/search.py | 2 -- synapse/storage/databases/main/state.py | 1 - synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/stream.py | 1 + synapse/storage/databases/main/transactions.py | 1 - synapse/storage/databases/main/user_directory.py | 1 - synapse/storage/databases/state/bg_updates.py | 1 - synapse/storage/databases/state/store.py | 7 ++-- synapse/storage/prepare_database.py | 4 +-- synapse/types/state.py | 2 +- synapse/util/caches/__init__.py | 1 - synapse/util/check_dependencies.py | 2 +- synapse/util/patch_inline_callbacks.py | 1 - synmark/__main__.py | 2 -- synmark/suites/logging.py | 1 - tests/federation/test_complexity.py | 4 --- tests/federation/test_federation_server.py | 1 - tests/handlers/test_sso.py | 1 - tests/handlers/test_stats.py | 1 - tests/http/federation/test_srv_resolver.py | 1 - tests/http/test_client.py | 2 +- tests/push/test_bulk_push_rule_evaluator.py | 1 - tests/push/test_email.py | 2 -- tests/replication/slave/storage/test_events.py | 1 - tests/rest/admin/test_device.py | 3 -- tests/rest/admin/test_media.py | 5 --- tests/rest/admin/test_room.py | 1 - tests/rest/admin/test_server_notice.py | 1 - tests/rest/client/test_account.py | 4 --- tests/rest/client/test_auth.py | 2 -- tests/rest/client/test_capabilities.py | 1 - tests/rest/client/test_consent.py | 1 - tests/rest/client/test_directory.py | 1 - tests/rest/client/test_ephemeral_message.py | 1 - tests/rest/client/test_events.py | 3 -- tests/rest/client/test_filter.py | 1 - tests/rest/client/test_login.py | 2 -- tests/rest/client/test_login_token_request.py | 1 - tests/rest/client/test_presence.py | 1 - tests/rest/client/test_profile.py | 3 -- tests/rest/client/test_register.py | 4 --- tests/rest/client/test_rendezvous.py | 1 - tests/rest/client/test_rooms.py | 14 ++------ tests/rest/client/test_sync.py | 3 -- tests/rest/client/test_third_party_rules.py | 3 ++ tests/rest/media/test_media_retention.py | 1 - tests/rest/media/v1/test_media_storage.py | 3 -- tests/rest/media/v1/test_url_preview.py | 3 -- tests/server_notices/test_consent.py | 2 -- tests/storage/databases/main/test_deviceinbox.py | 1 - tests/storage/databases/main/test_receipts.py | 2 +- tests/storage/databases/main/test_room.py | 1 - tests/storage/test_client_ips.py | 1 - tests/storage/test_event_chain.py | 2 -- tests/storage/test_event_federation.py | 2 +- tests/storage/test_event_push_actions.py | 2 +- tests/storage/test_purge.py | 1 - tests/storage/test_roommember.py | 3 -- tests/storage/test_state.py | 30 ++++++++-------- tests/test_mau.py | 1 - 117 files changed, 108 insertions(+), 218 deletions(-) create mode 100644 changelog.d/15103.misc (limited to 'synapse/config') diff --git a/changelog.d/15103.misc b/changelog.d/15103.misc new file mode 100644 index 0000000000..65322498c9 --- /dev/null +++ b/changelog.d/15103.misc @@ -0,0 +1 @@ +Bump black from 22.12.0 to 23.1.0. diff --git a/poetry.lock b/poetry.lock index 4d724ab782..8ffdab7a22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -90,32 +90,46 @@ typecheck = ["mypy"] [[package]] name = "black" -version = "22.12.0" +version = "23.1.0" description = "The uncompromising code formatter." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, - {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, - {file = "black-22.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d30b212bffeb1e252b31dd269dfae69dd17e06d92b87ad26e23890f3efea366f"}, - {file = "black-22.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:7412e75863aa5c5411886804678b7d083c7c28421210180d67dfd8cf1221e1f4"}, - {file = "black-22.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c116eed0efb9ff870ded8b62fe9f28dd61ef6e9ddd28d83d7d264a38417dcee2"}, - {file = "black-22.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:1f58cbe16dfe8c12b7434e50ff889fa479072096d79f0a7f25e4ab8e94cd8350"}, - {file = "black-22.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77d86c9f3db9b1bf6761244bc0b3572a546f5fe37917a044e02f3166d5aafa7d"}, - {file = "black-22.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:82d9fe8fee3401e02e79767016b4907820a7dc28d70d137eb397b92ef3cc5bfc"}, - {file = "black-22.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101c69b23df9b44247bd88e1d7e90154336ac4992502d4197bdac35dd7ee3320"}, - {file = "black-22.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:559c7a1ba9a006226f09e4916060982fd27334ae1998e7a38b3f33a37f7a2148"}, - {file = "black-22.12.0-py3-none-any.whl", hash = "sha256:436cc9167dd28040ad90d3b404aec22cedf24a6e4d7de221bec2730ec0c97bcf"}, - {file = "black-22.12.0.tar.gz", hash = "sha256:229351e5a18ca30f447bf724d007f890f97e13af070bb6ad4c0a441cd7596a2f"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26"}, + {file = "black-23.1.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b"}, + {file = "black-23.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"}, + {file = "black-23.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648"}, + {file = "black-23.1.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958"}, + {file = "black-23.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a"}, + {file = "black-23.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481"}, + {file = "black-23.1.0-cp37-cp37m-macosx_10_16_x86_64.whl", hash = "sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad"}, + {file = "black-23.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8"}, + {file = "black-23.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd"}, + {file = "black-23.1.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580"}, + {file = "black-23.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468"}, + {file = "black-23.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06"}, + {file = "black-23.1.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739"}, + {file = "black-23.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9"}, + {file = "black-23.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555"}, + {file = "black-23.1.0-py3-none-any.whl", hash = "sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32"}, + {file = "black-23.1.0.tar.gz", hash = "sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac"}, ] [package.dependencies] click = ">=8.0.0" mypy-extensions = ">=0.4.3" +packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi index 1fe1a136f1..0e745c0a79 100644 --- a/stubs/sortedcontainers/sortedlist.pyi +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -29,7 +29,6 @@ _Repr = Callable[[], str] def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... class SortedList(MutableSequence[_T]): - DEFAULT_LOAD_FACTOR: int = ... def __init__( self, diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py index 2b74a40166..19ca399d44 100644 --- a/synapse/_scripts/register_new_matrix_user.py +++ b/synapse/_scripts/register_new_matrix_user.py @@ -47,7 +47,6 @@ def request_registration( _print: Callable[[str], None] = print, exit: Callable[[int], None] = sys.exit, ) -> None: - url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) # Get the nonce @@ -154,7 +153,6 @@ def register_new_user( def main() -> None: - logging.captureWarnings(True) parser = argparse.ArgumentParser( diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 0d35e0af8f..2c9cbf8b27 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -1205,7 +1205,6 @@ class CursesProgress(Progress): if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) else: - if self.total_processed > 0: left = float(self.total_remaining) / self.total_processed diff --git a/synapse/_scripts/synctl.py b/synapse/_scripts/synctl.py index b4c96ad7f3..077b90935e 100755 --- a/synapse/_scripts/synctl.py +++ b/synapse/_scripts/synctl.py @@ -167,7 +167,6 @@ Worker = collections.namedtuple( def main() -> None: - parser = argparse.ArgumentParser() parser.add_argument( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index a5aa2185a2..28062dd69d 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -213,7 +213,7 @@ def handle_startup_exception(e: Exception) -> NoReturn: def redirect_stdio_to_logs() -> None: streams = [("stdout", LogLevel.info), ("stderr", LogLevel.error)] - for (stream, level) in streams: + for stream, level in streams: oldStream = getattr(sys, stream) loggingFile = LoggingFile( logger=twisted.logger.Logger(namespace=stream), diff --git a/synapse/app/complement_fork_starter.py b/synapse/app/complement_fork_starter.py index 920538f44d..c8dc3f9d76 100644 --- a/synapse/app/complement_fork_starter.py +++ b/synapse/app/complement_fork_starter.py @@ -219,7 +219,7 @@ def main() -> None: # memory space and don't need to repeat the work of loading the code! # Instead of using fork() directly, we use the multiprocessing library, # which uses fork() on Unix platforms. - for (func, worker_args) in zip(worker_functions, args_by_worker): + for func, worker_args in zip(worker_functions, args_by_worker): process = multiprocessing.Process( target=_worker_entrypoint, args=(func, proxy_reactor, worker_args) ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 946f3a3807..0dec24369a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -157,7 +157,6 @@ class GenericWorkerServer(HomeServer): DATASTORE_CLASS = GenericWorkerSlavedStore # type: ignore def _listen_http(self, listener_config: ListenerConfig) -> None: - assert listener_config.http_options is not None # We always include a health resource. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6176a70eb2..b8830b1a9c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -321,7 +321,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer: and not config.registration.registrations_require_3pid and not config.registration.registration_requires_token ): - raise ConfigError( "You have enabled open registration without any verification. This is a known vector for " "spam and abuse. If you would like to allow public registration, please consider adding email, " diff --git a/synapse/config/consent.py b/synapse/config/consent.py index be74609dc4..5bfd0cbb71 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py @@ -22,7 +22,6 @@ from ._base import Config class ConsentConfig(Config): - section = "consent" def __init__(self, *args: Any): diff --git a/synapse/config/database.py b/synapse/config/database.py index 928fec8dfe..596d8769fe 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -154,7 +154,6 @@ class DatabaseConfig(Config): logger.warning(NON_SQLITE_DATABASE_PATH_WARNING) def set_databasepath(self, database_path: str) -> None: - if database_path != ":memory:": database_path = self.abspath(database_path) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4d2b298a70..c205a78039 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -56,7 +56,6 @@ from .workers import WorkerConfig class HomeServerConfig(RootConfig): - config_classes = [ ModulesConfig, ServerConfig, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index b733fac617..a5514e70a2 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -46,7 +46,6 @@ class RatelimitConfig(Config): section = "ratelimiting" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index e4759711ed..2da40c09f0 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -116,7 +116,6 @@ class ContentRepositoryConfig(Config): section = "media" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( diff --git a/synapse/config/server.py b/synapse/config/server.py index d4ef9930b0..0e46b849cf 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -735,7 +735,6 @@ class ServerConfig(Config): listeners: Optional[List[dict]], **kwargs: Any, ) -> str: - _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 336fe3e0da..318270ebb8 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -30,7 +30,6 @@ class TlsConfig(Config): section = "tls" def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.tls_certificate_file = self.abspath(config.get("tls_certificate_path")) self.tls_private_key_file = self.abspath(config.get("tls_private_key_path")) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 86cd4af9bd..d710607c63 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -399,7 +399,7 @@ class Keyring: # We now convert the returned list of results into a map from server # name to key ID to FetchKeyResult, to return. to_return: Dict[str, Dict[str, FetchKeyResult]] = {} - for (request, results) in zip(deduped_requests, results_per_request): + for request, results in zip(deduped_requests, results_per_request): to_return_by_server = to_return.setdefault(request.server_name, {}) for key_id, key_result in results.items(): existing = to_return_by_server.get(key_id) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 97c61cc258..9a25ed419b 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -78,7 +78,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: # correctly, we need to await its result. Therefore it doesn't make a lot of # sense to make it go through the run() wrapper. if f.__name__ == "check_event_allowed": - # We need to wrap check_event_allowed because its old form would return either # a boolean or a dict, but now we want to return the dict separately from the # boolean. @@ -100,7 +99,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: return wrap_check_event_allowed if f.__name__ == "on_create_room": - # We need to wrap on_create_room because its old form would return a boolean # if the room creation is denied, but now we just want it to raise an # exception. diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index d720b5fd3f..3063df7990 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -314,7 +314,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): # stream position. keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} - for ((destination, edu_key), pos) in keyed_edus.items(): + for (destination, edu_key), pos in keyed_edus.items(): rows.append( ( pos, @@ -329,7 +329,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): j = self.edus.bisect_right(to_token) + 1 edus = self.edus.items()[i:j] - for (pos, edu) in edus: + for pos, edu in edus: rows.append((pos, EduRow(edu))) # Sort rows based on pos diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5d1d21cdc8..ec3ab968e9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -737,7 +737,7 @@ class ApplicationServicesHandler: ) ret = [] - for (success, result) in results: + for success, result in results: if success: ret.extend(result) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index cf12b55d21..b12bc4c9a3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -815,7 +815,6 @@ class AuthHandler: now_ms = self._clock.time_msec() if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms: - raise SynapseError( HTTPStatus.FORBIDDEN, "The supplied refresh token has expired", @@ -2259,7 +2258,6 @@ class PasswordAuthProvider: async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - # call all of the on_logged_out callbacks for callback in self.on_logged_out_callbacks: try: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index a5798e9483..1fb23cc9bf 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -497,9 +497,11 @@ class DirectoryHandler: raise SynapseError(403, "Not allowed to publish room") # Check if publishing is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 83f53ceb88..50317ec753 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -188,7 +188,6 @@ class E2eRoomKeysHandler: # XXX: perhaps we should use a finer grained lock here? async with self._upload_linearizer.queue(user_id): - # Check that the version we're trying to upload is the current version try: version_info = await self.store.get_e2e_room_keys_version_info(user_id) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 46dd63c3f0..c508861b6a 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -236,7 +236,6 @@ class EventAuthHandler: # in any of them. allowed_rooms = await self.get_rooms_that_allow_join(state_ids) if not await self.is_user_in_rooms(allowed_rooms, user_id): - # If this is a remote request, the user might be in an allowed room # that we do not know about. if get_domain_from_id(user_id) != self._server_name: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 1a29abde98..aead0b44b9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -124,7 +124,6 @@ class InitialSyncHandler: as_client_event: bool = True, include_archived: bool = False, ) -> JsonDict: - memberships = [Membership.INVITE, Membership.JOIN] if include_archived: memberships.append(Membership.LEAVE) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 87af31aa27..4ad2233573 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler): ) if self.unpersisted_users_changes: - await self.store.update_presence( [ self.user_to_current_state[user_id] @@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler): now = self.clock.time_msec() with Measure(self.clock, "presence_update_states"): - # NOTE: We purposefully don't await between now and when we've # calculated what we want to do with the new states, to avoid races. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 37c87c8351..a26ec02284 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -868,9 +868,11 @@ class RoomCreationHandler: ) # Check whether this visibility value is blocked by a third party module - allowed_by_third_party_rules = await ( - self.third_party_event_rules.check_visibility_can_be_modified( - room_id, visibility + allowed_by_third_party_rules = ( + await ( + self.third_party_event_rules.check_visibility_can_be_modified( + room_id, visibility + ) ) ) if not allowed_by_third_party_rules: diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index c73d2adaad..5d4ca0e2d2 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -374,7 +374,7 @@ class RoomBatchHandler: # correct stream_ordering as they are backfilled (which decrements). # Events are sorted by (topological_ordering, stream_ordering) # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): + for event, context in reversed(events_to_persist): # This call can't raise `PartialStateConflictError` since we forbid # use of the historical batch API during partial state await self.event_creation_handler.handle_new_client_event( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e4595312c..fd6d946c37 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1297,7 +1297,6 @@ class SyncHandler: return RoomNotifCounts.empty() with Measure(self.clock, "unread_notifs_for_room_id"): - return await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5aed71262f..c70eee649c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -524,6 +524,7 @@ def whitelisted_homeserver(destination: str) -> bool: # Start spans and scopes + # Could use kwargs but I want these to be explicit def start_active_span( operation_name: str, diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index b01372565d..8ce5887229 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -87,7 +87,6 @@ class LaterGauge(Collector): ] def collect(self) -> Iterable[Metric]: - g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) try: diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py index b7d47ce3e7..a22c4e5bbd 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py @@ -139,7 +139,6 @@ def install_gc_manager() -> None: class PyPyGCStats(Collector): def collect(self) -> Iterable[Metric]: - # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 5fc38431ba..8f834be774 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -330,7 +330,6 @@ class BulkPushRuleEvaluator: context: EventContext, event_id_to_event: Mapping[str, EventBase], ) -> None: - if ( not event.internal_metadata.is_notifiable() or event.internal_metadata.is_historical() diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 2374f810c9..111ec07e64 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -265,7 +265,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] - return {} async def _handle_request( # type: ignore[override] diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index ecea6fc915..cc3929dcf5 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -195,7 +195,6 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): async def _serialize_payload( # type: ignore[override] user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: - return { "user_id": user_id, "device_id": device_id, diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index fd1c0ec6af..dfc061eb5e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -328,7 +328,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): outbound_redis_connection: txredisapi.ConnectionHandler, channel_names: List[str], ): - super().__init__( hs, uuid="subscriber", diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 14b6705862..ad9b760713 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -139,7 +139,6 @@ class EventsStream(Stream): current_token: Token, target_row_count: int, ) -> StreamUpdateResult: - # the events stream merges together three separate sources: # * new events # * current_state changes diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1d6e4982d7..4de56bf13f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -75,7 +75,6 @@ class RoomRestV2Servlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) await assert_user_is_admin(self._auth, requester) @@ -144,7 +143,6 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) if not RoomID.is_valid(room_id): @@ -181,7 +179,6 @@ class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, delete_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self._auth, request) delete_status = self._pagination_handler.get_delete_status(delete_id) @@ -438,7 +435,6 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 0c0bf540b9..7cc4db20d6 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -683,8 +683,12 @@ class AccountValidityRenewServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if self.account_activity_handler.on_legacy_admin_request_callback: - expiration_ts = await ( - self.account_activity_handler.on_legacy_admin_request_callback(request) + expiration_ts = ( + await ( + self.account_activity_handler.on_legacy_admin_request_callback( + request + ) + ) ) else: body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index eb77337044..276a1b405d 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -97,7 +97,6 @@ class AuthRestServlet(RestServlet): return None async def on_POST(self, request: Request, stagetype: str) -> None: - session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index cc1c2f9731..236199897c 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -79,7 +79,6 @@ class CreateFilterRestServlet(RestServlet): async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 3cb1e7e375..bce806f2bb 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -628,10 +628,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = await ( - self.password_auth_provider.get_username_for_registration( - auth_result, - params, + desired_username = ( + await ( + self.password_auth_provider.get_username_for_registration( + auth_result, + params, + ) ) ) @@ -682,9 +684,11 @@ class RegisterRestServlet(RestServlet): session_id ) - display_name = await ( - self.password_auth_provider.get_displayname_for_registration( - auth_result, params + display_name = ( + await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) ) ) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 6e035afcce..ef8334ae25 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -270,7 +270,6 @@ async def respond_with_responder( logger.debug("Responding to media request with responder %s", responder) add_file_headers(request, media_type, file_size, upload_name) try: - await responder.write_to_consumer(request) except Exception as e: # The majority of the time this will be due to the client having gone diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 9480cc5763..f909a4fb9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -38,7 +38,6 @@ class ThumbnailError(Exception): class Thumbnailer: - FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} @staticmethod diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 8e61aba454..0d75d9739a 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - for (user_id, messages_by_device) in edu["messages"].items(): - for (device_id, msg) in messages_by_device.items(): + for user_id, messages_by_device in edu["messages"].items(): + for device_id, msg in messages_by_device.items(): with start_active_span("store_outgoing_to_device_message"): set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) @@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, ) -> Tuple[int, bool]: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1ca66d57d4..0dd15f16ff 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -512,7 +512,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): results.append(("org.matrix.signing_key_update", result)) if issue_8631_logger.isEnabledFor(logging.DEBUG): - for (user_id, edu) in results: + for user_id, edu in results: issue_8631_logger.debug( "device update to %s for %s from %s to %s: %s", destination, @@ -1316,7 +1316,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) """ count = 0 - for (destination, user_id, stream_id, device_id) in rows: + for destination, user_id, stream_id, device_id in rows: txn.execute( delete_sql, (destination, user_id, stream_id, stream_id, device_id) ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 6240f9a75e..9f8d2e4bea 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No backup with that version exists") values = [] - for (room_id, session_id, room_key) in room_keys: + for room_id, session_id, room_key in room_keys: values.append( ( user_id, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2c2d145666..b9c39b1718 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # add each cross-signing signature to the correct device in the result dict. - for (user_id, key_id, device_id, signature) in cross_sigs_result: + for user_id, key_id, device_id, signature in cross_sigs_result: target_device_result = result[user_id][device_id] # We've only looked up cross-signatures for non-deleted devices with key # data. @@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker # devices. user_list = [] user_device_list = [] - for (user_id, device_id) in query_list: + for user_id, device_id in query_list: if device_id is None: user_list.append(user_id) else: @@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker txn.execute(sql, query_params) - for (user_id, device_id, display_name, key_json) in txn: + for user_id, device_id, display_name, key_json in txn: assert device_id is not None if include_deleted_devices: deleted_devices.remove((user_id, device_id)) @@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_query_clauses = [] signature_query_params = [] - for (user_id, device_id) in device_query: + for user_id, device_id in device_query: signature_query_clauses.append( "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca780cca36..ff3edeb716 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas latest_events: List[str], limit: int, ) -> List[str]: - seen_events = set(earliest_events) front = set(latest_events) - seen_events event_results: List[str] = [] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7996cbb557..73b8aea16c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -469,7 +469,6 @@ class PersistEventsStore: txn: LoggingTransaction, events: List[EventBase], ) -> None: - # We only care about state events, so this if there are no state events. if not any(e.is_state() for e in events): return diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 584536111d..0a275e6ce6 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): nbrows = 0 last_row_event_id = "" - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) @@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): results = list(txn) # (event_id, parent_id, rel_type) for each relation relations_to_insert: List[Tuple[str, str, str]] = [] - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) except Exception as e: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d0ef10258..b7e7498125 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1493,7 +1493,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(redactions_sql + clause, args) - for (redacter, redacted) in txn: + for redacter, redacted in txn: d = event_dict.get(redacted) if d: d.redactions.append(redacter) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b202c5eb87..fa8be214ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[Dict[str, Any]], int]: - # Set ordering order_by_column = MediaSortOrder(order_by).value diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df53e726e6..fddbc07afa 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -344,7 +344,6 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT name FROM users WHERE deactivated = ? and name > ? @@ -392,7 +391,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, access_token FROM pushers AS p LEFT JOIN access_tokens AS a ON (p.access_token = a.id) @@ -449,7 +447,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey FROM pushers AS p diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dddf49c2d5..92a82240ab 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): def _populate_receipt_event_stream_ordering_txn( txn: LoggingTransaction, ) -> bool: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 644bbb8878..39f89291b2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2168,7 +2168,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): def _get_event_report_txn( txn: LoggingTransaction, report_id: int ) -> Optional[Dict[str, Any]]: - sql = """ SELECT er.id, diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3fe433f66c..a7aae661d8 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore): class SearchBackgroundUpdateStore(SearchWorkerStore): - EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" @@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. # https://sqlite.org/optoverview.html#crossjoin # diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ba325d390b..ebb2ae964f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7b7d0c3c9..d3393d8e49 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore): insert_cols = [] qargs = [] - for (key, val) in chain( + for key, val in chain( keyvalues.items(), absolutes.items(), additive_relatives.items() ): insert_cols.append(key) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 818c46182e..ac5fbf6b86 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" + # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 6b33d809b6..6d72bd9f67 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def get_destination_rooms_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == Direction.BACKWARDS: order = "DESC" else: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 30af4b3b6c..c3f2b61bd5 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -98,7 +98,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def _populate_user_directory_createtables( self, progress: JsonDict, batch_size: int ) -> int: - # Get all the rooms that we want to process. def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index d743282f13..097dea5182 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 1a7232b276..89b1faa6c8 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -257,14 +257,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( + non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + member_state, incomplete_groups_m = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 6c335a9315..2a1c6fa31b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -563,7 +563,7 @@ def _apply_module_schemas( """ # This is the old way for password_auth_provider modules to make changes # to the database. This should instead be done using the module API - for (mod, _config) in config.authproviders.password_providers: + for mod, _config in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) @@ -591,7 +591,7 @@ def _apply_module_schema_files( (modname,), ) applied_deltas = {d for d, in cur} - for (name, stream) in names_and_streams: + for name, stream in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/types/state.py b/synapse/types/state.py index 743a4f9217..4b3071acce 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -120,7 +120,7 @@ class StateFilter: def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: """The inverse to `from_types`.""" - for (event_type, state_keys) in self.types.items(): + for event_type, state_keys in self.types.items(): if state_keys is None: yield event_type, None else: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 9387632d0d..6ffa56217e 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -98,7 +98,6 @@ class EvictionReason(Enum): @attr.s(slots=True, auto_attribs=True) class CacheMetric: - _cache: Sized _cache_type: str _cache_name: str diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py index 3b1e205700..1c0fde4966 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py @@ -183,7 +183,7 @@ def check_requirements(extra: Optional[str] = None) -> None: deps_unfulfilled = [] errors = [] - for (requirement, must_be_installed) in dependencies: + for requirement, must_be_installed in dependencies: try: dist: metadata.Distribution = metadata.distribution(requirement.name) except metadata.PackageNotFoundError: diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index f97f98a057..d00d34e652 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -211,7 +211,6 @@ def _check_yield_points( result = Failure() if current_context() != expected_context: - # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the # deferred we waited on didn't follow the rules, or we forgot to diff --git a/synmark/__main__.py b/synmark/__main__.py index 35a59e347a..19de639187 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -34,12 +34,10 @@ def make_test(main): """ def _main(loops): - reactor = make_reactor() file_out = StringIO() with redirect_stderr(file_out): - d = Deferred() d.addCallback(lambda _: ensureDeferred(main(reactor, loops))) diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index 9419892e95..8beb077e0a 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -30,7 +30,6 @@ from synapse.util import Clock class LineCounter(LineOnlyReceiver): - delimiter = b"\n" def __init__(self, *args, **kwargs): diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 35dd9a20df..33af8770fd 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -24,7 +24,6 @@ from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -37,7 +36,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): return config def test_complexity_simple(self) -> None: - u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") @@ -71,7 +69,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(complexity, 1.23) def test_join_too_large(self) -> None: - u1 = self.register_user("u1", "pass") handler = self.hs.get_room_member_handler() @@ -131,7 +128,6 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def test_join_too_large_once_joined(self) -> None: - u1 = self.register_user("u1", "pass") u1_token = self.login("u1", "pass") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index bba6469b55..6c7738d810 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -34,7 +34,6 @@ from tests.unittest import override_config class FederationServerTests(unittest.FederatingHomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py index 137deab138..d6f43a98fc 100644 --- a/tests/handlers/test_sso.py +++ b/tests/handlers/test_sso.py @@ -113,7 +113,6 @@ async def mock_get_file( headers: Optional[RawHeaders] = None, is_allowed_content_type: Optional[Callable[[str], bool]] = None, ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: - fake_response = FakeResponse(code=404) if url == "http://my.server/me.png": fake_response = FakeResponse( diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index f1a50c5bcb..d11ded6c5b 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6 class StatsRoomTests(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 7748f56ee6..6ab13357f9 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -46,7 +46,6 @@ class SrvResolverTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup() -> Generator["Deferred[object]", object, List[Server]]: - with LoggingContext("one") as ctx: resolve_d = resolver.resolve_service(service_name) result: List[Server] diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 9cfe1ad0de..f6d6684985 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -149,7 +149,7 @@ class BlacklistingAgentTest(TestCase): self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1" # Configure the reactor's DNS resolver. - for (domain, ip) in ( + for domain, ip in ( (self.safe_domain, self.safe_ip), (self.unsafe_domain, self.unsafe_ip), (self.allowed_domain, self.allowed_ip), diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 199e3d7b70..dce6899e78 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -33,7 +33,6 @@ from tests.unittest import HomeserverTestCase, override_config class TestBulkPushRuleEvaluator(HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7563f33fdc..0a3aca5c50 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -39,7 +39,6 @@ class _User: class EmailPusherTests(HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -48,7 +47,6 @@ class EmailPusherTests(HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["email"] = { "enable_notifs": True, diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index ddca9d696c..57c781a0c3 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -64,7 +64,6 @@ def patch__eq__(cls: object) -> Callable[[], None]: class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = EventsWorkerStore def setUp(self) -> None: diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 03f2112b07..aaa488bced 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -28,7 +28,6 @@ from tests import unittest class DeviceRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): class DevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, @@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index db77a45ae3..f41319a5b6 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -594,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -724,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -821,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 453a6e979c..9dbb778679 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): class JoinAliasRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, room.register_servlets, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index f71ff46d87..28b999573e 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class ServerNoticeTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e2ee1a1766..2b05dffc7d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): class DeactivateTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase): class WhoamiTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): - servlets = [ account.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index a144610078..0d8fe77b88 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -52,7 +52,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): class FallbackAuthTests(unittest.HomeserverTestCase): - servlets = [ auth.register_servlets, register.register_servlets, @@ -60,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = True diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index d1751e1557..c16e8d43f4 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -26,7 +26,6 @@ from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index b1ca81a911..bb845179d3 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): hijack_auth = False def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["form_secret"] = "123abc" diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index 7a88aa2cda..6490e883bf 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -28,7 +28,6 @@ from tests.unittest import override_config class DirectoryTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, directory.register_servlets, diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 9fa1f82dfe..f31ebc8021 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -26,7 +26,6 @@ from tests import unittest class EphemeralMessageTestCase(unittest.HomeserverTestCase): - user_id = "@user:test" servlets = [ diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index a9b7db9db2..54df2a252c 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_registration_captcha"] = False config["enable_registration"] = True @@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") @@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # register an account self.user_id = self.register_user("sid1", "pass") self.token = self.login(self.user_id, "pass") diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 830762fd53..91678abf13 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha" class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" hijack_auth = True EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index ff5baa9f0a..62acf4f44e 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [ class LoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, ] diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index 6aedc1a11c..b8187db982 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, admin.register_servlets, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 67e16880e6..dcbb125a3b 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -35,7 +35,6 @@ class PresenceTestCase(unittest.HomeserverTestCase): servlets = [presence.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.presence_handler = Mock(spec=PresenceHandler) self.presence_handler.set_state.return_value = make_awaitable(None) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 8de5a342ae..27c93ad761 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -30,7 +30,6 @@ from tests import unittest class ProfileTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase): class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, @@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets_for_client_rest_resource, login.register_servlets, diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 4c561f9525..b228dba861 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -40,7 +40,6 @@ from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register.register_servlets, @@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - servlets = [ register.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index c0eb5d01a6..8dbd64be55 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous" class RendezvousServletTestCase(unittest.HomeserverTestCase): - servlets = [ rendezvous.register_servlets, ] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index cfad182b2f..4dd763096d 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase): servlets = [room.register_servlets, room.register_deprecated_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.hs = self.setup_test_homeserver( "red", federation_http_client=None, @@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase): rmcreator_id = "@notme:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase): class RoomJoinTestCase(RoomBase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): hijack_auth = False def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Register the user who does the searching self.user_id2 = self.register_user("user", "pass") self.access_token = self.login("user", "pass") @@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase): class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() @@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["allow_public_rooms_without_auth"] = True self.hs = self.setup_test_homeserver(config=config) @@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase): class ContextTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): class ThreepidInviteTestCase(unittest.HomeserverTestCase): - servlets = [ admin.register_servlets, login.register_servlets, @@ -3438,7 +3427,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): """ Test allowing/blocking threepid invites with a spam-check module. - In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index b9047194dd..9c876c7a32 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -41,7 +41,6 @@ from tests.server import TimedOutException class FilterTestCase(unittest.HomeserverTestCase): - user_id = "@apple:test" servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): class ExcludeRoomTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 5fa3440691..c0f93f898a 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """Tests that a forbidden event is forbidden from being sent, but an allowed one can be sent. """ + # patch the rules module with a Mock which will return False for some event # types async def check( @@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_modify_event(self) -> None: """The module can return a modified version of the event""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] @@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_message_edit(self) -> None: """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event async def check( ev: EventBase, state: StateMap[EventBase] diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index 23f227aed6..b59d9dfd4d 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -31,7 +31,6 @@ from tests.utils import MockClock class MediaRetentionTestCase(unittest.HomeserverTestCase): - ONE_DAY_IN_MS = 24 * 60 * 60 * 1000 THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 17a3b06a8e..8ed27179c4 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -52,7 +52,6 @@ from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): - needs_threadpool = True def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -207,7 +206,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches: List[ Tuple[ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", @@ -268,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_resource = hs.get_media_repository_resource() self.download_resource = media_resource.children[b"download"] self.thumbnail_resource = media_resource.children[b"thumbnail"] diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 2c321f8d04..6fcf60ce19 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["url_preview_enabled"] = True config["max_spider_size"] = 9999999 @@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.preview_url = self.media_repo.children[b"preview_url"] @@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): addressTypes: Optional[Sequence[Type[IAddress]]] = None, transportSemantics: str = "TCP", ) -> IResolutionReceiver: - resolution = HostResolution(hostName) resolutionReceiver.resolutionBegan(resolution) if hostName not in self.lookups: diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 6540ed53f1..3fdf5a6d52 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -25,7 +25,6 @@ from tests import unittest class ConsentNoticesTests(unittest.HomeserverTestCase): - servlets = [ sync.register_servlets, synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -34,7 +33,6 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - tmpdir = self.mktemp() os.mkdir(tmpdir) self.consent_notice_message = "consent %(consent_uri)s" diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 373707b275..b6d5c474b0 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, devices.register_servlets, diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index ac77aec003..71db47405e 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -26,7 +26,6 @@ from tests.unittest import HomeserverTestCase class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, @@ -62,6 +61,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase): keys and expected receipt key-values after duplicate receipts have been removed. """ + # First, undo the background update. def drop_receipts_unique_index(txn: LoggingTransaction) -> None: txn.execute(f"DROP INDEX IF EXISTS {index_name}") diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 3108ca3444..dbd8f3a85e 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -27,7 +27,6 @@ from tests.unittest import HomeserverTestCase class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 7f7f4ef892..cd0079871c 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -656,7 +656,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index a10e5fa8b1..73d11e7786 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -417,7 +417,6 @@ class EventChainStoreTestCase(HomeserverTestCase): def fetch_chains( self, events: List[EventBase] ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]: - # Fetch the map from event ID -> (chain ID, sequence number) rows = self.get_success( self.store.db_pool.simple_select_many_batch( @@ -492,7 +491,6 @@ class LinkMapTestCase(unittest.TestCase): class EventChainBackgroundUpdateTestCase(HomeserverTestCase): - servlets = [ admin.register_servlets, room.register_servlets, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 8fc7936ab0..3e1984c15c 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -672,7 +672,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): complete_event_dict_map: Dict[str, JsonDict] = {} stream_ordering = 0 - for (event_id, prev_event_ids) in event_graph.items(): + for event_id, prev_event_ids in event_graph.items(): depth = depth_map[event_id] complete_event_dict_map[event_id] = { diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 76c06a9d1e..aa19c3bd30 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -774,7 +774,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): self.assertEqual(r, 3) # add a bunch of dummy events to the events table - for (stream_ordering, ts) in ( + for stream_ordering, ts in ( (3, 110), (4, 120), (5, 120), diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index d8f42c5d05..857e2caf2e 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase class PurgeTests(HomeserverTestCase): - user_id = "@red:server" servlets = [room.register_servlets] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 8794401823..f4c4661aaf 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -27,7 +27,6 @@ from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): - servlets = [ login.register_servlets, register_servlets_for_client_rest_resource, @@ -35,7 +34,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # type: ignore[override] - # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastores().main @@ -48,7 +46,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.u_charlie = UserID.from_string("@charlie:elsewhere") def test_one_member(self) -> None: - # Alice creates the room, and is automatically joined self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f730b888f7..e82c03f597 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -242,7 +242,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -259,7 +259,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -272,7 +272,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -289,7 +289,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -309,7 +309,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -327,7 +327,7 @@ class StateStoreTestCase(HomeserverTestCase): state_dict, ) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -341,7 +341,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -392,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -404,7 +404,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertDictEqual({}, state_dict) room_id = self.room.to_string() - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -417,7 +417,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -428,7 +428,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -447,7 +447,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -459,7 +459,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -473,7 +473,7 @@ class StateStoreTestCase(HomeserverTestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -485,7 +485,7 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( + state_dict, is_all = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( diff --git a/tests/test_mau.py b/tests/test_mau.py index 4e7665a22b..ff21098a59 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -32,7 +32,6 @@ from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): - servlets = [register.register_servlets, sync.register_servlets] def default_config(self) -> JsonDict: -- cgit 1.5.1 From ec79870f1422be47e8d6e85f315799888278969b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Feb 2023 16:06:42 -0500 Subject: Fix a typo in MSC3873 config option. (#15138) Previously the experimental configuration option referred to the wrong MSC number. --- changelog.d/15138.misc | 1 + synapse/config/experimental.py | 4 ++-- synapse/push/bulk_push_rule_evaluator.py | 12 ++++++------ tests/push/test_push_rule_evaluator.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/15138.misc (limited to 'synapse/config') diff --git a/changelog.d/15138.misc b/changelog.d/15138.misc new file mode 100644 index 0000000000..fb706b27f2 --- /dev/null +++ b/changelog.d/15138.misc @@ -0,0 +1 @@ +Fix a typo in an experimental config setting. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 54c91953e1..bc38fae0b6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -175,8 +175,8 @@ class ExperimentalConfig(Config): ) # MSC3873: Disambiguate event_match keys. - self.msc3783_escape_event_match_key = experimental.get( - "msc3783_escape_event_match_key", False + self.msc3873_escape_event_match_key = experimental.get( + "msc3873_escape_event_match_key", False ) # MSC3952: Intentional mentions, this depends on MSC3758. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8f834be774..3c4a152d6b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -276,7 +276,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events[relation_type] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) reply_event_id = ( @@ -294,7 +294,7 @@ class BulkPushRuleEvaluator: if related_event is not None: related_events["m.in_reply_to"] = _flatten_dict( related_event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ) # indicate that this is from a fallback relation. @@ -412,7 +412,7 @@ class BulkPushRuleEvaluator: evaluator = PushRuleEvaluator( _flatten_dict( event, - msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key, ), has_mentions, user_mentions, @@ -507,7 +507,7 @@ def _flatten_dict( prefix: Optional[List[str]] = None, result: Optional[Dict[str, JsonValue]] = None, *, - msc3783_escape_event_match_key: bool = False, + msc3873_escape_event_match_key: bool = False, ) -> Dict[str, JsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, @@ -536,7 +536,7 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): - if msc3783_escape_event_match_key: + if msc3873_escape_event_match_key: # Escape periods in the key with a backslash (and backslashes with an # extra backslash). This is since a period is used as a separator between # nested fields. @@ -552,7 +552,7 @@ def _flatten_dict( value, prefix=(prefix + [key]), result=result, - msc3783_escape_event_match_key=msc3783_escape_event_match_key, + msc3873_escape_event_match_key=msc3873_escape_event_match_key, ) # `room_version` should only ever be set when looking at the top level of an event diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index d320a12f96..4e858fd16f 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -54,7 +54,7 @@ class FlattenDictTestCase(unittest.TestCase): self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input)) self.assertEqual( {"m\\.foo.b\\\\ar": "abc"}, - _flatten_dict(input, msc3783_escape_event_match_key=True), + _flatten_dict(input, msc3873_escape_event_match_key=True), ) def test_non_string(self) -> None: -- cgit 1.5.1 From 4fc8875876374ec8f97a3b3cc344a4e3abcf769f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Feb 2023 08:26:05 -0500 Subject: Refactor media modules. (#15146) * Removes the `v1` directory from `test.rest.media.v1`. * Moves the non-REST code from `synapse.rest.media.v1` to `synapse.media`. * Flatten the `v1` directory from `synapse.rest.media`, but leave compatiblity with 3rd party media repositories and spam checkers. --- changelog.d/15146.misc | 1 + synapse/_scripts/move_remote_media_to_new_store.py | 2 +- synapse/config/repository.py | 12 +- synapse/events/spamcheck.py | 4 +- synapse/media/_base.py | 479 ++++++++ synapse/media/filepath.py | 410 +++++++ synapse/media/media_repository.py | 1038 ++++++++++++++++ synapse/media/media_storage.py | 374 ++++++ synapse/media/oembed.py | 265 +++++ synapse/media/preview_html.py | 501 ++++++++ synapse/media/storage_provider.py | 181 +++ synapse/media/thumbnailer.py | 221 ++++ synapse/rest/media/config_resource.py | 41 + synapse/rest/media/download_resource.py | 75 ++ synapse/rest/media/media_repository_resource.py | 93 ++ synapse/rest/media/preview_url_resource.py | 869 ++++++++++++++ synapse/rest/media/thumbnail_resource.py | 554 +++++++++ synapse/rest/media/upload_resource.py | 108 ++ synapse/rest/media/v1/_base.py | 470 +------- synapse/rest/media/v1/config_resource.py | 41 - synapse/rest/media/v1/download_resource.py | 76 -- synapse/rest/media/v1/filepath.py | 410 ------- synapse/rest/media/v1/media_repository.py | 1112 ------------------ synapse/rest/media/v1/media_storage.py | 365 +----- synapse/rest/media/v1/oembed.py | 265 ----- synapse/rest/media/v1/preview_html.py | 501 -------- synapse/rest/media/v1/preview_url_resource.py | 871 -------------- synapse/rest/media/v1/storage_provider.py | 172 +-- synapse/rest/media/v1/thumbnail_resource.py | 555 --------- synapse/rest/media/v1/thumbnailer.py | 221 ---- synapse/rest/media/v1/upload_resource.py | 108 -- synapse/server.py | 6 +- tests/media/__init__.py | 13 + tests/media/test_base.py | 38 + tests/media/test_filepath.py | 595 ++++++++++ tests/media/test_html_preview.py | 542 +++++++++ tests/media/test_media_storage.py | 792 +++++++++++++ tests/media/test_oembed.py | 162 +++ tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_user.py | 2 +- tests/rest/media/test_url_preview.py | 1234 ++++++++++++++++++++ tests/rest/media/v1/__init__.py | 13 - tests/rest/media/v1/test_base.py | 38 - tests/rest/media/v1/test_filepath.py | 595 ---------- tests/rest/media/v1/test_html_preview.py | 542 --------- tests/rest/media/v1/test_media_storage.py | 792 ------------- tests/rest/media/v1/test_oembed.py | 162 --- tests/rest/media/v1/test_url_preview.py | 1234 -------------------- 48 files changed, 8612 insertions(+), 8545 deletions(-) create mode 100644 changelog.d/15146.misc create mode 100644 synapse/media/_base.py create mode 100644 synapse/media/filepath.py create mode 100644 synapse/media/media_repository.py create mode 100644 synapse/media/media_storage.py create mode 100644 synapse/media/oembed.py create mode 100644 synapse/media/preview_html.py create mode 100644 synapse/media/storage_provider.py create mode 100644 synapse/media/thumbnailer.py create mode 100644 synapse/rest/media/config_resource.py create mode 100644 synapse/rest/media/download_resource.py create mode 100644 synapse/rest/media/media_repository_resource.py create mode 100644 synapse/rest/media/preview_url_resource.py create mode 100644 synapse/rest/media/thumbnail_resource.py create mode 100644 synapse/rest/media/upload_resource.py delete mode 100644 synapse/rest/media/v1/config_resource.py delete mode 100644 synapse/rest/media/v1/download_resource.py delete mode 100644 synapse/rest/media/v1/filepath.py delete mode 100644 synapse/rest/media/v1/media_repository.py delete mode 100644 synapse/rest/media/v1/oembed.py delete mode 100644 synapse/rest/media/v1/preview_html.py delete mode 100644 synapse/rest/media/v1/preview_url_resource.py delete mode 100644 synapse/rest/media/v1/thumbnail_resource.py delete mode 100644 synapse/rest/media/v1/thumbnailer.py delete mode 100644 synapse/rest/media/v1/upload_resource.py create mode 100644 tests/media/__init__.py create mode 100644 tests/media/test_base.py create mode 100644 tests/media/test_filepath.py create mode 100644 tests/media/test_html_preview.py create mode 100644 tests/media/test_media_storage.py create mode 100644 tests/media/test_oembed.py create mode 100644 tests/rest/media/test_url_preview.py delete mode 100644 tests/rest/media/v1/__init__.py delete mode 100644 tests/rest/media/v1/test_base.py delete mode 100644 tests/rest/media/v1/test_filepath.py delete mode 100644 tests/rest/media/v1/test_html_preview.py delete mode 100644 tests/rest/media/v1/test_media_storage.py delete mode 100644 tests/rest/media/v1/test_oembed.py delete mode 100644 tests/rest/media/v1/test_url_preview.py (limited to 'synapse/config') diff --git a/changelog.d/15146.misc b/changelog.d/15146.misc new file mode 100644 index 0000000000..8de5f95239 --- /dev/null +++ b/changelog.d/15146.misc @@ -0,0 +1 @@ +Refactor the media modules. diff --git a/synapse/_scripts/move_remote_media_to_new_store.py b/synapse/_scripts/move_remote_media_to_new_store.py index 819afaaca6..0dd36bee20 100755 --- a/synapse/_scripts/move_remote_media_to_new_store.py +++ b/synapse/_scripts/move_remote_media_to_new_store.py @@ -37,7 +37,7 @@ import os import shutil import sys -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.media.filepath import MediaFilePaths logger = logging.getLogger() diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2da40c09f0..ecb3edbe3a 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -178,11 +178,13 @@ class ContentRepositoryConfig(Config): for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend - if provider_config["module"] == "file_system": - provider_config["module"] = ( - "synapse.rest.media.v1.storage_provider" - ".FileStorageProviderBackend" - ) + if ( + provider_config["module"] == "file_system" + or provider_config["module"] == "synapse.rest.media.v1.storage_provider" + ): + provider_config[ + "module" + ] = "synapse.media.storage_provider.FileStorageProviderBackend" provider_class, parsed_config = load_module( provider_config, ("media_storage_providers", "" % i) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 623a2c71ea..765c15bb51 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -33,8 +33,8 @@ from typing_extensions import Literal import synapse from synapse.api.errors import Codes from synapse.logging.opentracing import trace -from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import ReadableFileWrapper +from synapse.media._base import FileInfo +from synapse.media.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import JsonDict, RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable diff --git a/synapse/media/_base.py b/synapse/media/_base.py new file mode 100644 index 0000000000..ef8334ae25 --- /dev/null +++ b/synapse/media/_base.py @@ -0,0 +1,479 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import urllib +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type + +import attr + +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender +from twisted.web.server import Request + +from synapse.api.errors import Codes, SynapseError, cs_error +from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.util.stringutils import is_ascii, parse_and_validate_server_name + +logger = logging.getLogger(__name__) + +# list all text content types that will have the charset default to UTF-8 when +# none is given +TEXT_CONTENT_TYPES = [ + "text/css", + "text/csv", + "text/html", + "text/calendar", + "text/plain", + "text/javascript", + "application/json", + "application/ld+json", + "application/rtf", + "image/svg+xml", + "text/xml", +] + + +def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: + """Parses the server name, media ID and optional file name from the request URI + + Also performs some rough validation on the server name. + + Args: + request: The `Request`. + + Returns: + A tuple containing the parsed server name, media ID and optional file name. + + Raises: + SynapseError(404): if parsing or validation fail for any reason + """ + try: + # The type on postpath seems incorrect in Twisted 21.2.0. + postpath: List[bytes] = request.postpath # type: ignore + assert postpath + + # This allows users to append e.g. /test.png to the URL. Useful for + # clients that parse the URL to see content type. + server_name_bytes, media_id_bytes = postpath[:2] + server_name = server_name_bytes.decode("utf-8") + media_id = media_id_bytes.decode("utf8") + + # Validate the server name, raising if invalid + parse_and_validate_server_name(server_name) + + file_name = None + if len(postpath) > 2: + try: + file_name = urllib.parse.unquote(postpath[-1].decode("utf-8")) + except UnicodeDecodeError: + pass + return server_name, media_id, file_name + except Exception: + raise SynapseError( + 404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN + ) + + +def respond_404(request: SynapseRequest) -> None: + respond_with_json( + request, + 404, + cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), + send_cors=True, + ) + + +async def respond_with_file( + request: SynapseRequest, + media_type: str, + file_path: str, + file_size: Optional[int] = None, + upload_name: Optional[str] = None, +) -> None: + logger.debug("Responding with %r", file_path) + + if os.path.isfile(file_path): + if file_size is None: + stat = os.stat(file_path) + file_size = stat.st_size + + add_file_headers(request, media_type, file_size, upload_name) + + with open(file_path, "rb") as f: + await make_deferred_yieldable(FileSender().beginFileTransfer(f, request)) + + finish_request(request) + else: + respond_404(request) + + +def add_file_headers( + request: Request, + media_type: str, + file_size: Optional[int], + upload_name: Optional[str], +) -> None: + """Adds the correct response headers in preparation for responding with the + media. + + Args: + request + media_type: The media/content type. + file_size: Size in bytes of the media, if known. + upload_name: The name of the requested file, if any. + """ + + def _quote(x: str) -> str: + return urllib.parse.quote(x.encode("utf-8")) + + # Default to a UTF-8 charset for text content types. + # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16' + if media_type.lower() in TEXT_CONTENT_TYPES: + content_type = media_type + "; charset=UTF-8" + else: + content_type = media_type + + request.setHeader(b"Content-Type", content_type.encode("UTF-8")) + if upload_name: + # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. + # + # `filename` is defined to be a `value`, which is defined by RFC2616 + # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` + # is (essentially) a single US-ASCII word, and a `quoted-string` is a + # US-ASCII string surrounded by double-quotes, using backslash as an + # escape character. Note that %-encoding is *not* permitted. + # + # `filename*` is defined to be an `ext-value`, which is defined in + # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, + # where `value-chars` is essentially a %-encoded string in the given charset. + # + # [1]: https://tools.ietf.org/html/rfc6266#section-4.1 + # [2]: https://tools.ietf.org/html/rfc2616#section-3.6 + # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1 + + # We avoid the quoted-string version of `filename`, because (a) synapse didn't + # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we + # may as well just do the filename* version. + if _can_encode_filename_as_token(upload_name): + disposition = "inline; filename=%s" % (upload_name,) + else: + disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),) + + request.setHeader(b"Content-Disposition", disposition.encode("ascii")) + + # cache for at least a day. + # XXX: we might want to turn this off for data we don't want to + # recommend caching as it's sensitive or private - or at least + # select private. don't bother setting Expires as all our + # clients are smart enough to be happy with Cache-Control + request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") + if file_size is not None: + request.setHeader(b"Content-Length", b"%d" % (file_size,)) + + # Tell web crawlers to not index, archive, or follow links in media. This + # should help to prevent things in the media repo from showing up in web + # search results. + request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex") + + +# separators as defined in RFC2616. SP and HT are handled separately. +# see _can_encode_filename_as_token. +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} + + +def _can_encode_filename_as_token(x: str) -> bool: + for c in x: + # from RFC2616: + # + # token = 1* + # + # separators = "(" | ")" | "<" | ">" | "@" + # | "," | ";" | ":" | "\" | <"> + # | "/" | "[" | "]" | "?" | "=" + # | "{" | "}" | SP | HT + # + # CHAR = + # + # CTL = + # + if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS: + return False + return True + + +async def respond_with_responder( + request: SynapseRequest, + responder: "Optional[Responder]", + media_type: str, + file_size: Optional[int], + upload_name: Optional[str] = None, +) -> None: + """Responds to the request with given responder. If responder is None then + returns 404. + + Args: + request + responder + media_type: The media/content type. + file_size: Size in bytes of the media. If not known it should be None + upload_name: The name of the requested file, if any. + """ + if not responder: + respond_404(request) + return + + # If we have a responder we *must* use it as a context manager. + with responder: + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return + + logger.debug("Responding to media request with responder %s", responder) + add_file_headers(request, media_type, file_size, upload_name) + try: + await responder.write_to_consumer(request) + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + + # Unregister the producer, if it has one, so Twisted doesn't complain + if request.producer: + request.unregisterProducer() + + finish_request(request) + + +class Responder(ABC): + """Represents a response that can be streamed to the requester. + + Responder is a context manager which *must* be used, so that any resources + held can be cleaned up. + """ + + @abstractmethod + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: + """Stream response into consumer + + Args: + consumer: The consumer to stream into. + + Returns: + Resolves once the response has finished being written + """ + raise NotImplementedError() + + def __enter__(self) -> None: # noqa: B027 + pass + + def __exit__( # noqa: B027 + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThumbnailInfo: + """Details about a generated thumbnail.""" + + width: int + height: int + method: str + # Content type of thumbnail, e.g. image/png + type: str + # The size of the media file, in bytes. + length: Optional[int] = None + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FileInfo: + """Details about a requested/uploaded file.""" + + # The server name where the media originated from, or None if local. + server_name: Optional[str] + # The local ID of the file. For local files this is the same as the media_id + file_id: str + # If the file is for the url preview cache + url_cache: bool = False + # Whether the file is a thumbnail or not. + thumbnail: Optional[ThumbnailInfo] = None + + # The below properties exist to maintain compatibility with third-party modules. + @property + def thumbnail_width(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.width + + @property + def thumbnail_height(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.height + + @property + def thumbnail_method(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.method + + @property + def thumbnail_type(self) -> Optional[str]: + if not self.thumbnail: + return None + return self.thumbnail.type + + @property + def thumbnail_length(self) -> Optional[int]: + if not self.thumbnail: + return None + return self.thumbnail.length + + +def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: + """ + Get the filename of the downloaded file by inspecting the + Content-Disposition HTTP header. + + Args: + headers: The HTTP request headers. + + Returns: + The filename, or None. + """ + content_disposition = headers.get(b"Content-Disposition", [b""]) + + # No header, bail out. + if not content_disposition[0]: + return None + + _, params = _parse_header(content_disposition[0]) + + upload_name = None + + # First check if there is a valid UTF-8 filename + upload_name_utf8 = params.get(b"filename*", None) + if upload_name_utf8: + if upload_name_utf8.lower().startswith(b"utf-8''"): + upload_name_utf8 = upload_name_utf8[7:] + # We have a filename*= section. This MUST be ASCII, and any UTF-8 + # bytes are %-quoted. + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass + + # If there isn't check for an ascii name. + if not upload_name: + upload_name_ascii = params.get(b"filename", None) + if upload_name_ascii and is_ascii(upload_name_ascii): + upload_name = upload_name_ascii.decode("ascii") + + # This may be None here, indicating we did not find a matching name. + return upload_name + + +def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]: + """Parse a Content-type like header. + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + line: header to be parsed + + Returns: + The main content-type, followed by the parameter dictionary + """ + parts = _parseparam(b";" + line) + key = next(parts) + pdict = {} + for p in parts: + i = p.find(b"=") + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1 :].strip() + + # strip double-quotes + if len(value) >= 2 and value[0:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"') + pdict[name] = value + + return key, pdict + + +def _parseparam(s: bytes) -> Generator[bytes, None, None]: + """Generator which splits the input on ;, respecting double-quoted sequences + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + s: header to be parsed + + Returns: + The split input + """ + while s[:1] == b";": + s = s[1:] + + # look for the next ; + end = s.find(b";") + + # if there is an odd number of " marks between here and the next ;, skip to the + # next ; instead + while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: + end = s.find(b";", end + 1) + + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] diff --git a/synapse/media/filepath.py b/synapse/media/filepath.py new file mode 100644 index 0000000000..1f6441c412 --- /dev/null +++ b/synapse/media/filepath.py @@ -0,0 +1,410 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import re +import string +from typing import Any, Callable, List, TypeVar, Union, cast + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") + + +F = TypeVar("F", bound=Callable[..., str]) + + +def _wrap_in_base_path(func: F) -> F: + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + + @functools.wraps(func) + def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str: + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return cast(F, _wrapped) + + +GetPathMethod = TypeVar( + "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]] +) + + +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: + """Wraps a path-returning method to check that the returned path(s) do not escape + the media store directory. + + The path-returning method may return either a single path, or a list of paths. + + The check is not expected to ever fail, unless `func` is missing a call to + `_validate_path_component`, or `_validate_path_component` is buggy. + + Args: + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. + + Returns: + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. + """ + + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner + + +ALLOWED_CHARACTERS = set( + string.ascii_letters + + string.digits + + "_-" + + ".[]:" # Domain names, IPv6 addresses and ports in server names +) +FORBIDDEN_NAMES = { + "", + os.path.curdir, # "." for the current platform + os.path.pardir, # ".." for the current platform +} + + +def _validate_path_component(name: str) -> str: + """Checks that the given string can be safely used as a path component + + Args: + name: The path component to check. + + Returns: + The path component if valid. + + Raises: + ValueError: If `name` cannot be safely used as a path component. + """ + if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES: + raise ValueError(f"Invalid path component: {name!r}") + + return name + + +class MediaFilePaths: + """Describes where files are stored on disk. + + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ + + def __init__(self, primary_base_path: str): + self.base_path = primary_base_path + self.normalized_base_path = os.path.normpath(self.base_path) + + # Refuse to initialize if paths cannot be validated correctly for the current + # platform. + assert os.path.sep not in ALLOWED_CHARACTERS + assert os.path.altsep not in ALLOWED_CHARACTERS + # On Windows, paths have all sorts of weirdness which `_validate_path_component` + # does not consider. In any case, the remote media store can't work correctly + # for certain homeservers there, since ":"s aren't allowed in paths. + assert os.name == "posix" + + @_wrap_with_jail_check(relative=True) + def local_media_filepath_rel(self, media_id: str) -> str: + return os.path.join( + "local_content", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + @_wrap_with_jail_check(relative=True) + def local_media_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + return os.path.join( + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), + ) + + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + @_wrap_with_jail_check(relative=False) + def local_media_thumbnail_dir(self, media_id: str) -> str: + """ + Retrieve the local store path of thumbnails of a given media_id + + Args: + media_id: The media ID to query. + Returns: + Path of local_thumbnails from media_id + """ + return os.path.join( + self.base_path, + "local_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + @_wrap_with_jail_check(relative=True) + def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: + return os.path.join( + "remote_content", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + ) + + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + @_wrap_with_jail_check(relative=True) + def remote_media_thumbnail_rel( + self, + server_name: str, + file_id: str, + width: int, + height: int, + content_type: str, + method: str, + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + return os.path.join( + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), + ) + + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + + # Legacy path that was used to store thumbnails previously. + # Should be removed after some time, when most of the thumbnails are stored + # using the new path. + @_wrap_with_jail_check(relative=True) + def remote_media_thumbnail_rel_legacy( + self, server_name: str, file_id: str, width: int, height: int, content_type: str + ) -> str: + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + return os.path.join( + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + _validate_path_component(file_name), + ) + + @_wrap_with_jail_check(relative=False) + def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: + return os.path.join( + self.base_path, + "remote_thumbnail", + _validate_path_component(server_name), + _validate_path_component(file_id[0:2]), + _validate_path_component(file_id[2:4]), + _validate_path_component(file_id[4:]), + ) + + @_wrap_with_jail_check(relative=True) + def url_cache_filepath_rel(self, media_id: str) -> str: + if NEW_FORMAT_ID_RE.match(media_id): + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + "url_cache", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) + else: + return os.path.join( + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + + @_wrap_with_jail_check(relative=False) + def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[:10]) + ) + ] + else: + return [ + os.path.join( + self.base_path, + "url_cache", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, "url_cache", _validate_path_component(media_id[0:2]) + ), + ] + + @_wrap_with_jail_check(relative=True) + def url_cache_thumbnail_rel( + self, media_id: str, width: int, height: int, content_type: str, method: str + ) -> str: + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + _validate_path_component(file_name), + ) + else: + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + _validate_path_component(file_name), + ) + + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + + @_wrap_with_jail_check(relative=True) + def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ) + else: + return os.path.join( + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ) + + url_cache_thumbnail_directory = _wrap_in_base_path( + url_cache_thumbnail_directory_rel + ) + + @_wrap_with_jail_check(relative=False) + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + _validate_path_component(media_id[11:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[:10]), + ), + ] + else: + return [ + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + _validate_path_component(media_id[4:]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + _validate_path_component(media_id[2:4]), + ), + os.path.join( + self.base_path, + "url_cache_thumbnails", + _validate_path_component(media_id[0:2]), + ), + ] diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py new file mode 100644 index 0000000000..b81e3c2b0c --- /dev/null +++ b/synapse/media/media_repository.py @@ -0,0 +1,1038 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import errno +import logging +import os +import shutil +from io import BytesIO +from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +from matrix_common.types.mxc_uri import MXCUri + +import twisted.internet.error +import twisted.web.http +from twisted.internet.defer import Deferred + +from synapse.api.errors import ( + FederationDeniedError, + HttpResponseException, + NotFoundError, + RequestSendFailed, + SynapseError, +) +from synapse.config.repository import ThumbnailRequirement +from synapse.http.site import SynapseRequest +from synapse.logging.context import defer_to_thread +from synapse.media._base import ( + FileInfo, + Responder, + ThumbnailInfo, + get_filename_from_headers, + respond_404, + respond_with_responder, +) +from synapse.media.filepath import MediaFilePaths +from synapse.media.media_storage import MediaStorage +from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.thumbnailer import Thumbnailer, ThumbnailError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.async_helpers import Linearizer +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# How often to run the background job to update the "recently accessed" +# attribute of local and remote media. +UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute +# How often to run the background job to check for local and remote media +# that should be purged according to the configured media retention settings. +MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000 # 1 hour + + +class MediaRepository: + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.client = hs.get_federation_http_client() + self.clock = hs.get_clock() + self.server_name = hs.hostname + self.store = hs.get_datastores().main + self.max_upload_size = hs.config.media.max_upload_size + self.max_image_pixels = hs.config.media.max_image_pixels + + Thumbnailer.set_limits(self.max_image_pixels) + + self.primary_base_path: str = hs.config.media.media_store_path + self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) + + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + self.thumbnail_requirements = hs.config.media.thumbnail_requirements + + self.remote_media_linearizer = Linearizer(name="media_remote") + + self.recently_accessed_remotes: Set[Tuple[str, str]] = set() + self.recently_accessed_locals: Set[str] = set() + + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) + + # List of StorageProviders where we should search for media and + # potentially upload to. + storage_providers = [] + + for ( + clz, + provider_config, + wrapper_config, + ) in hs.config.media.media_storage_providers: + backend = clz(hs, provider_config) + provider = StorageProviderWrapper( + backend, + store_local=wrapper_config.store_local, + store_remote=wrapper_config.store_remote, + store_synchronous=wrapper_config.store_synchronous, + ) + storage_providers.append(provider) + + self.media_storage = MediaStorage( + self.hs, self.primary_base_path, self.filepaths, storage_providers + ) + + self.clock.looping_call( + self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS + ) + + # Media retention configuration options + self._media_retention_local_media_lifetime_ms = ( + hs.config.media.media_retention_local_media_lifetime_ms + ) + self._media_retention_remote_media_lifetime_ms = ( + hs.config.media.media_retention_remote_media_lifetime_ms + ) + + # Check whether local or remote media retention is configured + if ( + hs.config.media.media_retention_local_media_lifetime_ms is not None + or hs.config.media.media_retention_remote_media_lifetime_ms is not None + ): + # Run the background job to apply media retention rules routinely, + # with the duration between runs dictated by the homeserver config. + self.clock.looping_call( + self._start_apply_media_retention_rules, + MEDIA_RETENTION_CHECK_PERIOD_MS, + ) + + def _start_update_recently_accessed(self) -> Deferred: + return run_as_background_process( + "update_recently_accessed_media", self._update_recently_accessed + ) + + def _start_apply_media_retention_rules(self) -> Deferred: + return run_as_background_process( + "apply_media_retention_rules", self._apply_media_retention_rules + ) + + async def _update_recently_accessed(self) -> None: + remote_media = self.recently_accessed_remotes + self.recently_accessed_remotes = set() + + local_media = self.recently_accessed_locals + self.recently_accessed_locals = set() + + await self.store.update_cached_last_access_time( + local_media, remote_media, self.clock.time_msec() + ) + + def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None: + """Mark the given media as recently accessed. + + Args: + server_name: Origin server of media, or None if local + media_id: The media ID of the content + """ + if server_name: + self.recently_accessed_remotes.add((server_name, media_id)) + else: + self.recently_accessed_locals.add(media_id) + + async def create_content( + self, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: UserID, + ) -> MXCUri: + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type: The content type of the file. + upload_name: The name of the file, if provided. + content: A file like object that is the content to store + content_length: The length of the content + auth_user: The user_id of the uploader + + Returns: + The mxc url of the stored content + """ + + media_id = random_string(24) + + file_info = FileInfo(server_name=None, file_id=media_id) + + fname = await self.media_storage.store_file(content, file_info) + + logger.info("Stored local media in file %r", fname) + + await self.store.store_local_media( + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=content_length, + user_id=auth_user, + ) + + await self._generate_thumbnails(None, media_id, media_id, media_type) + + return MXCUri(self.server_name, media_id) + + async def get_local_media( + self, request: SynapseRequest, media_id: str, name: Optional[str] + ) -> None: + """Responds to requests for local media, if exists, or returns 404. + + Args: + request: The incoming request. + media_id: The media ID of the content. (This is the same as + the file_id for local content.) + name: Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Resolves once a response has successfully been written to request + """ + media_info = await self.store.get_local_media(media_id) + if not media_info or media_info["quarantined_by"]: + respond_404(request) + return + + self.mark_recently_accessed(None, media_id) + + media_type = media_info["media_type"] + if not media_type: + media_type = "application/octet-stream" + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + url_cache = media_info["url_cache"] + + file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) + + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( + request, responder, media_type, media_length, upload_name + ) + + async def get_remote_media( + self, + request: SynapseRequest, + server_name: str, + media_id: str, + name: Optional[str], + ) -> None: + """Respond to requests for remote media. + + Args: + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as + the filename in the Content-Disposition header of the response. + + Returns: + Resolves once a response has successfully been written to request + """ + if ( + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + self.mark_recently_accessed(server_name, media_id) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + async with self.remote_media_linearizer.queue(key): + responder, media_info = await self._get_remote_media_impl( + server_name, media_id + ) + + # We deliberately stream the file outside the lock + if responder: + media_type = media_info["media_type"] + media_length = media_info["media_length"] + upload_name = name if name else media_info["upload_name"] + await respond_with_responder( + request, responder, media_type, media_length, upload_name + ) + else: + respond_404(request) + + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + """Gets the media info associated with the remote file, downloading + if necessary. + + Args: + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + + Returns: + The media info of the file + """ + if ( + self.federation_domain_whitelist is not None + and server_name not in self.federation_domain_whitelist + ): + raise FederationDeniedError(server_name) + + # We linearize here to ensure that we don't try and download remote + # media multiple times concurrently + key = (server_name, media_id) + async with self.remote_media_linearizer.queue(key): + responder, media_info = await self._get_remote_media_impl( + server_name, media_id + ) + + # Ensure we actually use the responder so that it releases resources + if responder: + with responder: + pass + + return media_info + + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: + """Looks for media in local cache, if not there then attempt to + download from remote server. + + Args: + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the + remote server). + + Returns: + A tuple of responder and the media info of the file. + """ + media_info = await self.store.get_cached_remote_media(server_name, media_id) + + # file_id is the ID we use to track the file locally. If we've already + # seen the file then reuse the existing ID, otherwise generate a new + # one. + + # If we have an entry in the DB, try and look for it + if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + if media_info["quarantined_by"]: + logger.info("Media is quarantined") + raise NotFoundError() + + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + + responder = await self.media_storage.fetch_media(file_info) + if responder: + return responder, media_info + + # Failed to find the file anywhere, lets download it. + + try: + media_info = await self._download_remote_file( + server_name, + media_id, + ) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) + + responder = await self.media_storage.fetch_media(file_info) + return responder, media_info + + async def _download_remote_file( + self, + server_name: str, + media_id: str, + ) -> dict: + """Attempt to download the remote file from the given server name, + using the given file_id as the local id. + + Args: + server_name: Originating server + media_id: The media ID of the content (as defined by the + remote server). This is different than the file_id, which is + locally generated. + file_id: Local file ID + + Returns: + The media info of the file. + """ + + file_id = random_string(24) + + file_info = FileInfo(server_name=server_name, file_id=file_id) + + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + request_path = "/".join( + ("/_matrix/media/r0/download", server_name, media_id) + ) + try: + length, headers = await self.client.get_file( + server_name, + request_path, + output_stream=f, + max_size=self.max_upload_size, + args={ + # tell the remote server to 404 if it doesn't + # recognise the server_name, to make sure we don't + # end up with a routing loop. + "allow_remote": "false" + }, + ) + except RequestSendFailed as e: + logger.warning( + "Request failed fetching remote media %s/%s: %r", + server_name, + media_id, + e, + ) + raise SynapseError(502, "Failed to fetch remote media") + + except HttpResponseException as e: + logger.warning( + "HTTP error fetching remote media %s/%s: %s", + server_name, + media_id, + e.response, + ) + if e.code == twisted.web.http.NOT_FOUND: + raise e.to_synapse_error() + raise SynapseError(502, "Failed to fetch remote media") + + except SynapseError: + logger.warning( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise + except NotRetryingDestination: + logger.warning("Not retrying destination %r", server_name) + raise SynapseError(502, "Failed to fetch remote media") + except Exception: + logger.exception( + "Failed to fetch remote media %s/%s", server_name, media_id + ) + raise SynapseError(502, "Failed to fetch remote media") + + await finish() + + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) + + logger.info("Stored remote media in file %r", fname) + + media_info = { + "media_type": media_type, + "media_length": length, + "upload_name": upload_name, + "created_ts": time_now_ms, + "filesystem_id": file_id, + } + + return media_info + + def _get_thumbnail_requirements( + self, media_type: str + ) -> Tuple[ThumbnailRequirement, ...]: + scpos = media_type.find(";") + if scpos > 0: + media_type = media_type[:scpos] + return self.thumbnail_requirements.get(media_type, ()) + + def _generate_thumbnail( + self, + thumbnailer: Thumbnailer, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[BytesIO]: + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, + ) + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = thumbnailer.transpose() + + if t_method == "crop": + return thumbnailer.crop(t_width, t_height, t_type) + elif t_method == "scale": + t_width, t_height = thumbnailer.aspect(t_width, t_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + return thumbnailer.scale(t_width, t_height, t_type) + + return None + + async def generate_local_exact_thumbnail( + self, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + url_cache: bool, + ) -> Optional[str]: + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(None, media_id, url_cache=url_cache) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None + + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=None, + file_id=media_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + output_path = await self.media_storage.store_file( + t_byte_source, file_info + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) + + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + return output_path + + # Could not generate thumbnail. + return None + + async def generate_remote_exact_thumbnail( + self, + server_name: str, + file_id: str, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[str]: + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None + + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) + + if t_byte_source: + try: + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + output_path = await self.media_storage.store_file( + t_byte_source, file_info + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) + + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + + return output_path + + # Could not generate thumbnail. + return None + + async def _generate_thumbnails( + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: + """Generate and store thumbnails for an image. + + Args: + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as + the file_id for local content) + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. + """ + requirements = self._get_thumbnail_requirements(media_type) + if not requirements: + return None + + input_path = await self.media_storage.ensure_media_is_in_local_cache( + FileInfo(server_name, file_id, url_cache=url_cache) + ) + + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, + ) + return None + + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height + + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, + ) + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose + ) + + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails: Dict[Tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + with self.media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) + + return {"width": m_width, "height": m_height} + + async def _apply_media_retention_rules(self) -> None: + """ + Purge old local and remote media according to the media retention rules + defined in the homeserver config. + """ + # Purge remote media + if self._media_retention_remote_media_lifetime_ms is not None: + # Calculate a threshold timestamp derived from the configured lifetime. Any + # media that has not been accessed since this timestamp will be removed. + remote_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms + ) + + logger.info( + "Purging remote media last accessed before" + f" {remote_media_threshold_timestamp_ms}" + ) + + await self.delete_old_remote_media( + before_ts=remote_media_threshold_timestamp_ms + ) + + # And now do the same for local media + if self._media_retention_local_media_lifetime_ms is not None: + # This works the same as the remote media threshold + local_media_threshold_timestamp_ms = ( + self.clock.time_msec() - self._media_retention_local_media_lifetime_ms + ) + + logger.info( + "Purging local media last accessed before" + f" {local_media_threshold_timestamp_ms}" + ) + + await self.delete_old_local_media( + before_ts=local_media_threshold_timestamp_ms, + keep_profiles=True, + delete_quarantined_media=False, + delete_protected_media=False, + ) + + async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]: + old_media = await self.store.get_remote_media_ids( + before_ts, include_quarantined_media=False + ) + + deleted = 0 + + for media in old_media: + origin = media["media_origin"] + media_id = media["media_id"] + file_id = media["filesystem_id"] + key = (origin, media_id) + + logger.info("Deleting: %r", key) + + # TODO: Should we delete from the backup store + + async with self.remote_media_linearizer.queue(key): + full_path = self.filepaths.remote_media_filepath(origin, file_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r", full_path) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(origin, media_id) + deleted += 1 + + return {"deleted": deleted} + + async def delete_local_media_ids( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete the given local or remote media ID from this server + + Args: + media_id: The media ID to delete. + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + return await self._remove_local_media_from_disk(media_ids) + + async def delete_old_local_media( + self, + before_ts: int, + size_gt: int = 0, + keep_profiles: bool = True, + delete_quarantined_media: bool = False, + delete_protected_media: bool = False, + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server by size and timestamp. Removes + media files, any thumbnails and cached URLs. + + Args: + before_ts: Unix timestamp in ms. + Files that were last used before this timestamp will be deleted. + size_gt: Size of the media in bytes. Files that are larger will be deleted. + keep_profiles: Switch to delete also files that are still used in image data + (e.g user profile, room avatar). If false these files will be deleted. + delete_quarantined_media: If True, media marked as quarantined will be deleted. + delete_protected_media: If True, media marked as protected will be deleted. + + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + old_media = await self.store.get_local_media_ids( + before_ts, + size_gt, + keep_profiles, + include_quarantined_media=delete_quarantined_media, + include_protected_media=delete_protected_media, + ) + return await self._remove_local_media_from_disk(old_media) + + async def _remove_local_media_from_disk( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server. Removes media files, + any thumbnails and cached URLs. + + Args: + media_ids: List of media_id to delete + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + removed_media = [] + for media_id in media_ids: + logger.info("Deleting media with ID '%s'", media_id) + full_path = self.filepaths.local_media_filepath(media_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r: %s", full_path, e) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(self.server_name, media_id) + + await self.store.delete_url_cache((media_id,)) + await self.store.delete_url_cache_media((media_id,)) + + removed_media.append(media_id) + + return removed_media, len(removed_media) diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py new file mode 100644 index 0000000000..a7e22a91e1 --- /dev/null +++ b/synapse/media/media_storage.py @@ -0,0 +1,374 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import os +import shutil +from types import TracebackType +from typing import ( + IO, + TYPE_CHECKING, + Any, + Awaitable, + BinaryIO, + Callable, + Generator, + Optional, + Sequence, + Tuple, + Type, +) + +import attr + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.protocols.basic import FileSender + +import synapse +from synapse.api.errors import NotFoundError +from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock +from synapse.util.file_consumer import BackgroundFileConsumer + +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.media.storage_provider import StorageProvider + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class MediaStorage: + """Responsible for storing/fetching files from local sources. + + Args: + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. + """ + + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): + self.hs = hs + self.reactor = hs.get_reactor() + self.local_media_directory = local_media_directory + self.filepaths = filepaths + self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() + + async def store_file(self, source: IO, file_info: FileInfo) -> str: + """Write `source` to the on disk media store, and also any other + configured storage providers + + Args: + source: A file like object that should be written + file_info: Info about the file to store + + Returns: + the file path written to in the primary media store + """ + + with self.store_into_file(file_info) as (f, fname, finish_cb): + # Write to the main repository + await self.write_to_file(source, f) + await finish_cb() + + return fname + + async def write_to_file(self, source: IO, output: IO) -> None: + """Asynchronously write the `source` to `output`.""" + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + + @contextlib.contextmanager + def store_into_file( + self, file_info: FileInfo + ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: + """Context manager used to get a file like object to write into, as + described by file_info. + + Actually yields a 3-tuple (file, fname, finish_cb), where file is a file + like object that can be written to, fname is the absolute path of file + on disk, and finish_cb is a function that returns an awaitable. + + fname can be used to read the contents from after upload, e.g. to + generate thumbnails. + + finish_cb must be called and waited on after the file has been + successfully been written to. Should not be called if there was an + error. + + Args: + file_info: Info about the file to store + + Example: + + with media_storage.store_into_file(info) as (f, fname, finish_cb): + # .. write into f ... + await finish_cb() + """ + + path = self._file_info_to_path(file_info) + fname = os.path.join(self.local_media_directory, path) + + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + + finished_called = [False] + + try: + with open(fname, "wb") as f: + + async def finish() -> None: + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + spam_check = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam_check != synapse.module_api.NOT_SPAM: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + # We currently ignore any additional field returned by + # the spam-check API. + raise SpamMediaException(errcode=spam_check[0]) + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + + yield f, fname, finish + except Exception as e: + try: + os.remove(fname) + except Exception: + pass + + raise e from None + + if not finished_called: + raise Exception("Finished callback not called") + + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: + """Attempts to fetch media described by file_info from the local cache + and configured storage providers. + + Args: + file_info + + Returns: + Returns a Responder if the file was found, otherwise None. + """ + paths = [self._file_info_to_path(file_info)] + + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) + + for provider in self.storage_providers: + for path in paths: + res: Any = await provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) + + return None + + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: + """Ensures that the given file is in the local cache. Attempts to + download it from storage providers if it isn't. + + Args: + file_info + + Returns: + Full path to local file + """ + path = self._file_info_to_path(file_info) + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + return local_path + + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + + dirname = os.path.dirname(local_path) + os.makedirs(dirname, exist_ok=True) + + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + with res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + return local_path + + raise NotFoundError() + + def _file_info_to_path(self, file_info: FileInfo) -> str: + """Converts file_info into a relative path. + + The path is suitable for storing files under a directory, e.g. used to + store files on local FS under the base media repository directory. + """ + if file_info.url_cache: + if file_info.thumbnail: + return self.filepaths.url_cache_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.url_cache_filepath_rel(file_info.file_id) + + if file_info.server_name: + if file_info.thumbnail: + return self.filepaths.remote_media_thumbnail_rel( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.remote_media_filepath_rel( + file_info.server_name, file_info.file_id + ) + + if file_info.thumbnail: + return self.filepaths.local_media_thumbnail_rel( + media_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + method=file_info.thumbnail.method, + ) + return self.filepaths.local_media_filepath_rel(file_info.file_id) + + +def _write_file_synchronously(source: IO, dest: IO) -> None: + """Write `source` to the file like `dest` synchronously. Should be called + from a thread. + + Args: + source: A file like object that's to be written + dest: A file like object to be written to + """ + source.seek(0) # Ensure we read from the start of the file + shutil.copyfileobj(source, dest) + + +class FileResponder(Responder): + """Wraps an open file that can be sent to a request. + + Args: + open_file: A file like object to be streamed ot the client, + is closed when finished streaming. + """ + + def __init__(self, open_file: IO): + self.open_file = open_file + + def write_to_consumer(self, consumer: IConsumer) -> Deferred: + return make_deferred_yieldable( + FileSender().beginFileTransfer(self.open_file, consumer) + ) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True, auto_attribs=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2**14 + + clock: Clock + path: str + + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: + """Reads the file in chunks and calls the callback with each chunk.""" + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py new file mode 100644 index 0000000000..c0eaf04be5 --- /dev/null +++ b/synapse/media/oembed.py @@ -0,0 +1,265 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import html +import logging +import urllib.parse +from typing import TYPE_CHECKING, List, Optional + +import attr + +from synapse.media.preview_html import parse_html_description +from synapse.types import JsonDict +from synapse.util import json_decoder + +if TYPE_CHECKING: + from lxml import etree + + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class OEmbedResult: + # The Open Graph result (converted from the oEmbed result). + open_graph_result: JsonDict + # The author_name of the oEmbed result + author_name: Optional[str] + # Number of milliseconds to cache the content, according to the oEmbed response. + # + # This will be None if no cache-age is provided in the oEmbed response (or + # if the oEmbed response cannot be turned into an Open Graph response). + cache_age: Optional[int] + + +class OEmbedProvider: + """ + A helper for accessing oEmbed content. + + It can be used to check if a URL should be accessed via oEmbed and for + requesting/parsing oEmbed content. + """ + + def __init__(self, hs: "HomeServer"): + self._oembed_patterns = {} + for oembed_endpoint in hs.config.oembed.oembed_patterns: + api_endpoint = oembed_endpoint.api_endpoint + + # Only JSON is supported at the moment. This could be declared in + # the formats field. Otherwise, if the endpoint ends in .xml assume + # it doesn't support JSON. + if ( + oembed_endpoint.formats is not None + and "json" not in oembed_endpoint.formats + ) or api_endpoint.endswith(".xml"): + logger.info( + "Ignoring oEmbed endpoint due to not supporting JSON: %s", + api_endpoint, + ) + continue + + # Iterate through each URL pattern and point it to the endpoint. + for pattern in oembed_endpoint.url_patterns: + self._oembed_patterns[pattern] = api_endpoint + + def get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Args: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in self._oembed_patterns.items(): + if url_pattern.fullmatch(url): + # TODO Specify max height / width. + + # Note that only the JSON format is supported, some endpoints want + # this in the URL, others want it as an argument. + endpoint = endpoint.replace("{format}", "json") + + args = {"url": url, "format": "json"} + query_str = urllib.parse.urlencode(args, True) + return f"{endpoint}?{query_str}" + + # No match. + return None + + def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]: + """ + Search an HTML document for oEmbed autodiscovery information. + + Args: + tree: The parsed HTML body. + + Returns: + The URL to use for oEmbed information, or None if no URL was found. + """ + # Search for link elements with the proper rel and type attributes. + for tag in tree.xpath( + "//link[@rel='alternate'][@type='application/json+oembed']" + ): + if "href" in tag.attrib: + return tag.attrib["href"] + + # Some providers (e.g. Flickr) use alternative instead of alternate. + for tag in tree.xpath( + "//link[@rel='alternative'][@type='application/json+oembed']" + ): + if "href" in tag.attrib: + return tag.attrib["href"] + + return None + + def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult: + """ + Parse the oEmbed response into an Open Graph response. + + Args: + url: The URL which is being previewed (not the one which was + requested). + raw_body: The oEmbed response as JSON encoded as bytes. + + Returns: + json-encoded Open Graph data + """ + + try: + # oEmbed responses *must* be UTF-8 according to the spec. + oembed = json_decoder.decode(raw_body.decode("utf-8")) + except ValueError: + return OEmbedResult({}, None, None) + + # The version is a required string field, but not always provided, + # or sometimes provided as a float. Be lenient. + oembed_version = oembed.get("version", "1.0") + if oembed_version != "1.0" and oembed_version != 1: + return OEmbedResult({}, None, None) + + # Attempt to parse the cache age, if possible. + try: + cache_age = int(oembed.get("cache_age")) * 1000 + except (TypeError, ValueError): + # If the cache age cannot be parsed (e.g. wrong type or invalid + # string), ignore it. + cache_age = None + + # The oEmbed response converted to Open Graph. + open_graph_response: JsonDict = {"og:url": url} + + title = oembed.get("title") + if title and isinstance(title, str): + # A common WordPress plug-in seems to incorrectly escape entities + # in the oEmbed response. + open_graph_response["og:title"] = html.unescape(title) + + author_name = oembed.get("author_name") + if not isinstance(author_name, str): + author_name = None + + # Use the provider name and as the site. + provider_name = oembed.get("provider_name") + if provider_name and isinstance(provider_name, str): + open_graph_response["og:site_name"] = provider_name + + # If a thumbnail exists, use it. Note that dimensions will be calculated later. + thumbnail_url = oembed.get("thumbnail_url") + if thumbnail_url and isinstance(thumbnail_url, str): + open_graph_response["og:image"] = thumbnail_url + + # Process each type separately. + oembed_type = oembed.get("type") + if oembed_type == "rich": + html_str = oembed.get("html") + if isinstance(html_str, str): + calc_description_and_urls(open_graph_response, html_str) + + elif oembed_type == "photo": + # If this is a photo, use the full image, not the thumbnail. + url = oembed.get("url") + if url and isinstance(url, str): + open_graph_response["og:image"] = url + + elif oembed_type == "video": + open_graph_response["og:type"] = "video.other" + html_str = oembed.get("html") + if html_str and isinstance(html_str, str): + calc_description_and_urls(open_graph_response, oembed["html"]) + for size in ("width", "height"): + val = oembed.get(size) + if type(val) is int: + open_graph_response[f"og:video:{size}"] = val + + elif oembed_type == "link": + open_graph_response["og:type"] = "website" + + else: + logger.warning("Unknown oEmbed type: %s", oembed_type) + + return OEmbedResult(open_graph_response, author_name, cache_age) + + +def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: + results = [] + for tag in tree.xpath("//*/" + tag_name): + if "src" in tag.attrib: + results.append(tag.attrib["src"]) + return results + + +def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None: + """ + Calculate description for an HTML document. + + This uses lxml to convert the HTML document into plaintext. If errors + occur during processing of the document, an empty response is returned. + + Args: + open_graph_response: The current Open Graph summary. This is updated with additional fields. + html_body: The HTML document, as bytes. + + Returns: + The summary + """ + # If there's no body, nothing useful is going to be found. + if not html_body: + return + + from lxml import etree + + # Create an HTML parser. If this fails, log and return no metadata. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(html_body, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return + + # Attempt to find interesting URLs (images, videos, embeds). + if "og:image" not in open_graph_response: + image_urls = _fetch_urls(tree, "img") + if image_urls: + open_graph_response["og:image"] = image_urls[0] + + video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed") + if video_urls: + open_graph_response["og:video"] = video_urls[0] + + description = parse_html_description(tree) + if description: + open_graph_response["og:description"] = description diff --git a/synapse/media/preview_html.py b/synapse/media/preview_html.py new file mode 100644 index 0000000000..516d0434f0 --- /dev/null +++ b/synapse/media/preview_html.py @@ -0,0 +1,501 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import codecs +import logging +import re +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Union, +) + +if TYPE_CHECKING: + from lxml import etree + +logger = logging.getLogger(__name__) + +_charset_match = re.compile( + rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) +_xml_encoding_match = re.compile( + rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I +) +_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) + +# Certain elements aren't meant for display. +ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} + + +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + +def _get_html_media_encodings( + body: bytes, content_type: Optional[str] +) -> Iterable[str]: + """ + Get potential encoding of the body based on the (presumably) HTML body or the content-type header. + + The precedence used for finding a character encoding is: + + 1. tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Check if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding + + # TODO Support + + # Check if it has an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding + + # Finally, fallback to UTF-8, then windows-1252. + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback + + +def decode_body( + body: bytes, uri: str, content_type: Optional[str] = None +) -> Optional["etree.Element"]: + """ + This uses lxml to parse the HTML document. + + Args: + body: The HTML document, as bytes. + uri: The URI used to download the body. + content_type: The Content-Type header. + + Returns: + The parsed HTML body, or None if an error occurred during processed. + """ + # If there's no body, nothing useful is going to be found. + if not body: + return None + + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. + for encoding in _get_html_media_encodings(body, content_type): + try: + body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + + from lxml import etree + + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding=encoding) + + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body, parser) + + +def _get_meta_tags( + tree: "etree.Element", + property: str, + prefix: str, + property_mapper: Optional[Callable[[str], Optional[str]]] = None, +) -> Dict[str, Optional[str]]: + """ + Search for meta tags prefixed with a particular string. + + Args: + tree: The parsed HTML document. + property: The name of the property which contains the tag name, e.g. + "property" for Open Graph. + prefix: The prefix on the property to search for, e.g. "og" for Open Graph. + property_mapper: An optional callable to map the property to the Open Graph + form. Can return None for a key to ignore that key. + + Returns: + A map of tag name to value. + """ + results: Dict[str, Optional[str]] = {} + for tag in tree.xpath( + f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]" + ): + # if we've got more than 50 tags, someone is taking the piss + if len(results) >= 50: + logger.warning( + "Skipping parsing of Open Graph for page with too many '%s:' tags", + prefix, + ) + return {} + + key = tag.attrib[property] + if property_mapper: + key = property_mapper(key) + # None is a special value used to ignore a value. + if key is None: + continue + + results[key] = tag.attrib["content"] + + return results + + +def _map_twitter_to_open_graph(key: str) -> Optional[str]: + """ + Map a Twitter card property to the analogous Open Graph property. + + Args: + key: The Twitter card property (starts with "twitter:"). + + Returns: + The Open Graph property (starts with "og:") or None to have this property + be ignored. + """ + # Twitter card properties with no analogous Open Graph property. + if key == "twitter:card" or key == "twitter:creator": + return None + if key == "twitter:site": + return "og:site_name" + # Otherwise, swap twitter to og. + return "og" + key[7:] + + +def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: + """ + Parse the HTML document into an Open Graph response. + + This uses lxml to search the HTML document for Open Graph data (or + synthesizes it from the document). + + Args: + tree: The parsed HTML document. + + Returns: + The Open Graph response as a dictionary. + """ + + # Search for Open Graph (og:) meta tags, e.g.: + # + # "og:type" : "video", + # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw", + # "og:site_name" : "YouTube", + # "og:video:type" : "application/x-shockwave-flash", + # "og:description" : "Fun stuff happening here", + # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon", + # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg", + # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1", + # "og:video:width" : "1280" + # "og:video:height" : "720", + # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", + + og = _get_meta_tags(tree, "property", "og") + + # TODO: Search for properties specific to the different Open Graph types, + # such as article: meta tags, e.g.: + # + # "article:publisher" : "https://www.facebook.com/thethudonline" /> + # "article:author" content="https://www.facebook.com/thethudonline" /> + # "article:tag" content="baby" /> + # "article:section" content="Breaking News" /> + # "article:published_time" content="2016-03-31T19:58:24+00:00" /> + # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> + + # Search for Twitter Card (twitter:) meta tags, e.g.: + # + # "twitter:site" : "@matrixdotorg" + # "twitter:creator" : "@matrixdotorg" + # + # Twitter cards tags also duplicate Open Graph tags. + # + # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started + twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph) + # Merge the Twitter values with the Open Graph values, but do not overwrite + # information from Open Graph tags. + for key, value in twitter.items(): + if key not in og: + og[key] = value + + if "og:title" not in og: + # Attempt to find a title from the title tag, or the biggest header on the page. + title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + if title: + og["og:title"] = title[0].strip() + else: + og["og:title"] = None + + if "og:image" not in og: + meta_image = tree.xpath( + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" + ) + # If a meta image is found, use it. + if meta_image: + og["og:image"] = meta_image[0] + else: + # Try to find images which are larger than 10px by 10px. + # + # TODO: consider inlined CSS styles as well as width & height attribs + images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") + images = sorted( + images, + key=lambda i: ( + -1 * float(i.attrib["width"]) * float(i.attrib["height"]) + ), + ) + # If no images were found, try to find *any* images. + if not images: + images = tree.xpath("//img[@src][1]") + if images: + og["og:image"] = images[0].attrib["src"] + + # Finally, fallback to the favicon if nothing else. + else: + favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + if favicons: + og["og:image"] = favicons[0] + + if "og:description" not in og: + # Check the first meta description tag for content. + meta_description = tree.xpath( + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" + ) + # If a meta description is found with content, use it. + if meta_description: + og["og:description"] = meta_description[0] + else: + og["og:description"] = parse_html_description(tree) + elif og["og:description"]: + # This must be a non-empty string at this point. + assert isinstance(og["og:description"], str) + og["og:description"] = summarize_paragraphs([og["og:description"]]) + + # TODO: delete the url downloads to stop diskfilling, + # as we only ever cared about its OG + return og + + +def parse_html_description(tree: "etree.Element") -> Optional[str]: + """ + Calculate a text description based on an HTML document. + + Grabs any text nodes which are inside the tag, unless they are within + an HTML5 semantic markup tag (
,