From a699c044b6598954ec74322923dd6ec698cd272d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 27 Oct 2020 18:42:46 +0000 Subject: Abstract code for stripping room state into a separate method (#8671) This is a requirement for [knocking](https://github.com/matrix-org/synapse/pull/6739), and is abstracting some code that was originally used by the invite flow. I'm separating it out into this PR as it's a fairly contained change. For a bit of context: when you invite a user to a room, you send them [stripped state events](https://matrix.org/docs/spec/server_server/unstable#put-matrix-federation-v2-invite-roomid-eventid) as part of `invite_room_state`. This is so that their client can display useful information such as the room name and avatar. The same requirement applies to knocking, as it would be nice for clients to be able to display a list of rooms you've knocked on - room name and avatar included. The reason we're sending membership events down as well is in the case that you are invited to a room that does not have an avatar or name set. In that case, the client should use the displayname/avatar of the inviter. That information is located in the inviter's membership event. This is optional as knocks don't really have any user in the room to link up to. When you knock on a room, your knock is sent by you and inserted into the room. It wouldn't *really* make sense to show the avatar of a random user - plus it'd be a data leak. So I've opted not to send membership events to the client here. The UX on the client for when you knock on a room without a name/avatar is a separate problem. In essence this is just moving some inline code to a reusable store method. --- synapse/storage/databases/main/events_worker.py | 54 ++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6e7f16f39c..cd1f31aa62 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -31,6 +31,7 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import ( @@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, JsonDict, get_domain_from_id from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore): return event_map + async def get_stripped_room_state_from_event_context( + self, + context: EventContext, + state_types_to_include: List[EventTypes], + membership_user_id: Optional[str], + ) -> List[JsonDict]: + """ + Retrieve the stripped state from a room, given an event context to retrieve state + from as well as the state types to include. Optionally, include the membership + events from a specific user. + + "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys + are included from each state event. + + Args: + context: The event context to retrieve state of the room from. + state_types_to_include: The type of state events to include. + membership_user_id: An optional user ID to include the stripped membership state + events of. This is useful when generating the stripped state of a room for + invites. We want to send membership events of the inviter, so that the + invitee can display the inviter's profile information if the room lacks any. + + Returns: + A list of dictionaries, each representing a stripped state event from the room. + """ + current_state_ids = await context.get_current_state_ids() + + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + + # The state to include + state_to_include_ids = [ + e_id + for k, e_id in current_state_ids.items() + if k[0] in state_types_to_include + or (membership_user_id and k == (EventTypes.Member, membership_user_id)) + ] + + state_to_include = await self.get_events(state_to_include_ids) + + return [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "sender": e.sender, + } + for e in state_to_include.values() + ] + def _do_fetch(self, conn): """Takes a database connection and waits for requests for events from the _event_fetch_list queue. -- cgit 1.5.1 From a6ea1a957e8e38ca3f98d4da32ee49a40fcb4807 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Oct 2020 12:11:45 +0000 Subject: Don't pull event from DB when handling replication traffic. (#8669) I was trying to make it so that we didn't have to start a background task when handling RDATA, but that is a bigger job (due to all the code in `generic_worker`). However I still think not pulling the event from the DB may help reduce some DB usage due to replication, even if most workers will simply go and pull that event from the DB later anyway. Co-authored-by: Patrick Cloke --- changelog.d/8669.misc | 1 + synapse/notifier.py | 68 ++++++++++++++++++++----- synapse/replication/tcp/client.py | 20 +++++--- synapse/replication/tcp/streams/events.py | 21 +++++--- synapse/storage/databases/main/events_worker.py | 8 ++- 5 files changed, 87 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8669.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8669.misc b/changelog.d/8669.misc new file mode 100644 index 0000000000..5228105cd3 --- /dev/null +++ b/changelog.d/8669.misc @@ -0,0 +1 @@ +Don't pull event from DB when handling replication traffic. diff --git a/synapse/notifier.py b/synapse/notifier.py index eb56b26f21..a17352ef46 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -28,6 +28,7 @@ from typing import ( Union, ) +import attr from prometheus_client import Counter from twisted.internet import defer @@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): return bool(self.events) +@attr.s(slots=True, frozen=True) +class _PendingRoomEventEntry: + event_pos = attr.ib(type=PersistedEventPosition) + extra_users = attr.ib(type=Collection[UserID]) + + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + + class Notifier: """ This class is responsible for notifying any listeners when there are new events available for it. @@ -190,9 +202,7 @@ class Notifier: self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = ( - [] - ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] + self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -255,7 +265,29 @@ class Notifier: max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): - """ Used by handlers to inform the notifier something has happened + """Unwraps event and calls `on_new_room_event_args`. + """ + self.on_new_room_event_args( + event_pos=event_pos, + room_id=event.room_id, + event_type=event.type, + state_key=event.get("state_key"), + membership=event.content.get("membership"), + max_room_stream_token=max_room_stream_token, + extra_users=extra_users, + ) + + def on_new_room_event_args( + self, + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, + extra_users: Collection[UserID] = [], + ): + """Used by handlers to inform the notifier something has happened in the room, room event wise. This triggers the notifier to wake up any listeners that are @@ -266,7 +298,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((event_pos, event, extra_users)) + self.pending_new_room_events.append( + _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users, + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + ) self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() @@ -284,18 +325,19 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for event_pos, event, extra_users in pending: - if event_pos.persisted_after(max_room_stream_token): - self.pending_new_room_events.append((event_pos, event, extra_users)) + for entry in pending: + if entry.event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append(entry) else: if ( - event.type == EventTypes.Member - and event.membership == Membership.JOIN + entry.type == EventTypes.Member + and entry.membership == Membership.JOIN + and entry.state_key ): - self._user_joined_room(event.state_key, event.room_id) + self._user_joined_room(entry.state_key, entry.room_id) - users.update(extra_users) - rooms.add(event.room_id) + users.update(entry.extra_users) + rooms.add(entry.room_id) if users or rooms: self.on_new_event( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e27ee216f0..2618eb1e53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -141,21 +141,25 @@ class ReplicationDataHandler: if row.type != EventsStreamEventRow.TypeId: continue assert isinstance(row, EventsStreamRow) + assert isinstance(row.data, EventsStreamEventRow) - event = await self.store.get_event( - row.data.event_id, allow_rejected=True - ) - if event.rejected_reason: + if row.data.rejected: continue extra_users = () # type: Tuple[UserID, ...] - if event.type == EventTypes.Member: - extra_users = (UserID.from_string(event.state_key),) + if row.data.type == EventTypes.Member and row.data.state_key: + extra_users = (UserID.from_string(row.data.state_key),) max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - self.notifier.on_new_room_event( - event, event_pos, max_token, extra_users + self.notifier.on_new_room_event_args( + event_pos=event_pos, + max_room_stream_token=max_token, + extra_users=extra_users, + room_id=row.data.room_id, + event_type=row.data.type, + state_key=row.data.state_key, + membership=row.data.membership, ) # Notify any waiting deferreds. The list is ordered by position so we diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 82e9e0d64e..86a62b71eb 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -15,12 +15,15 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import List, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token +if TYPE_CHECKING: + from synapse.server import HomeServer + """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -81,12 +84,14 @@ class BaseEventsStreamRow: class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional - relates_to = attr.ib() # str, optional + event_id = attr.ib(type=str) + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + redacts = attr.ib(type=Optional[str]) + relates_to = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + rejected = attr.ib(type=bool) @attr.s(slots=True, frozen=True) @@ -113,7 +118,7 @@ class EventsStream(Stream): NAME = "events" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cd1f31aa62..5ae263827d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1117,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " AND instance_name = ?" " ORDER BY stream_ordering ASC" @@ -1152,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " AND out.instance_name = ?" -- cgit 1.5.1 From 31d721fbf6655080235003b5576110d477fa2353 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 28 Oct 2020 11:12:21 -0400 Subject: Add type hints to application services. (#8655) --- changelog.d/8655.misc | 1 + mypy.ini | 4 ++ synapse/handlers/appservice.py | 75 +++++++++++---------- synapse/handlers/auth.py | 23 +++++-- synapse/storage/databases/main/appservice.py | 98 +++++++++++++++++----------- 5 files changed, 122 insertions(+), 79 deletions(-) create mode 100644 changelog.d/8655.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8655.misc b/changelog.d/8655.misc new file mode 100644 index 0000000000..b588bdd3e2 --- /dev/null +++ b/changelog.d/8655.misc @@ -0,0 +1 @@ +Add more type hints to the application services code. diff --git a/mypy.ini b/mypy.ini index 1fbd8decf8..1ece2ba082 100644 --- a/mypy.ini +++ b/mypy.ini @@ -57,6 +57,7 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, @@ -82,6 +83,9 @@ ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True +[mypy-bcrypt] +ignore_missing_imports = True + [mypy-constantly] ignore_missing_imports = True diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3ed29a2c16..9fc8444228 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from prometheus_client import Counter @@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import Collection, JsonDict, RoomStreamToken, UserID +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") class ApplicationServicesHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -247,7 +250,9 @@ class ApplicationServicesHandler: service, "presence", new_token ) - async def _handle_typing(self, service: ApplicationService, new_token: int): + async def _handle_typing( + self, service: ApplicationService, new_token: int + ) -> List[JsonDict]: typing_source = self.event_sources.sources["typing"] # Get the typing events from just before current typing, _ = await typing_source.get_new_events_as( @@ -259,7 +264,7 @@ class ApplicationServicesHandler: ) return typing - async def _handle_receipts(self, service: ApplicationService): + async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) @@ -271,7 +276,7 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] - ): + ) -> List[JsonDict]: events = [] # type: List[JsonDict] presence_source = self.event_sources.sources["presence"] from_key = await self.store.get_type_stream_id_for_appservice( @@ -301,11 +306,11 @@ class ApplicationServicesHandler: return events - async def query_user_exists(self, user_id): + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. Args: - user_id(str): The user to query if they exist on any AS. + user_id: The user to query if they exist on any AS. Returns: True if this user exists on at least one application service. """ @@ -316,11 +321,13 @@ class ApplicationServicesHandler: return True return False - async def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: """Check if an application service knows this room alias exists. Args: - room_alias(RoomAlias): The room alias to query. + room_alias: The room alias to query. Returns: namedtuple: with keys "room_id" and "servers" or None if no association can be found. @@ -336,10 +343,13 @@ class ApplicationServicesHandler: ) if is_known_alias: # the alias exists now so don't query more ASes. - result = await self.store.get_association_from_room_alias(room_alias) - return result + return await self.store.get_association_from_room_alias(room_alias) + + return None - async def query_3pe(self, kind, protocol, fields): + async def query_3pe( + self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]] + ) -> List[JsonDict]: services = self._get_services_for_3pn(protocol) results = await make_deferred_yieldable( @@ -361,7 +371,9 @@ class ApplicationServicesHandler: return ret - async def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols( + self, only_protocol: Optional[str] = None + ) -> Dict[str, JsonDict]: services = self.store.get_app_services() protocols = {} # type: Dict[str, List[JsonDict]] @@ -379,7 +391,7 @@ class ApplicationServicesHandler: if info is not None: protocols[p].append(info) - def _merge_instances(infos): + def _merge_instances(infos: List[JsonDict]) -> JsonDict: if not infos: return {} @@ -394,19 +406,17 @@ class ApplicationServicesHandler: return combined - for p in protocols.keys(): - protocols[p] = _merge_instances(protocols[p]) + return {p: _merge_instances(protocols[p]) for p in protocols.keys()} - return protocols - - async def _get_services_for_event(self, event): + async def _get_services_for_event( + self, event: EventBase + ) -> List[ApplicationService]: """Retrieve a list of application services interested in this event. Args: - event(Event): The event to check. Can be None if alias_list is not. + event: The event to check. Can be None if alias_list is not. Returns: - list: A list of services interested in this - event based on the service regex. + A list of services interested in this event based on the service regex. """ services = self.store.get_app_services() @@ -420,17 +430,15 @@ class ApplicationServicesHandler: return interested_list - def _get_services_for_user(self, user_id): + def _get_services_for_user(self, user_id: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return interested_list + return [s for s in services if (s.is_interested_in_user(user_id))] - def _get_services_for_3pn(self, protocol): + def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return interested_list + return [s for s in services if s.is_interested_in_protocol(protocol)] - async def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id: str) -> bool: if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. @@ -445,9 +453,8 @@ class ApplicationServicesHandler: service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - async def _check_user_exists(self, user_id): + async def _check_user_exists(self, user_id: str) -> bool: unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = await self.query_user_exists(user_id) - return exists + return await self.query_user_exists(user_id) return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dd14ab69d7..276594f3d9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,10 +18,20 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import attr -import bcrypt # type: ignore[import] +import bcrypt import pymacaroons from synapse.api.constants import LoginType @@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -149,11 +162,7 @@ class SsoLoginExtraAttributes: class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 637a938bac..26eef6eb61 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,21 +15,31 @@ # limitations under the License. import logging import re -from typing import List +from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple -from synapse.appservice import ApplicationService, AppServiceTransaction +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + AppServiceTransaction, +) from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) -def _make_exclusive_regex(services_cache): +def _make_exclusive_regex( + services_cache: List[ApplicationService], +) -> Optional[Pattern]: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ @@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache): ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) + exclusive_user_pattern = re.compile( + exclusive_user_regex + ) # type: Optional[Pattern] else: # We handle this case specially otherwise the constructed regex # will always match - exclusive_user_regex = None + exclusive_user_pattern = None - return exclusive_user_regex + return exclusive_user_pattern class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id: str) -> bool: """Check if the user is one associated with an app service (exclusively) """ if self.exclusive_user_regex: @@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): else: return False - def get_app_service_by_user_id(self, user_id): + def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore): a user ID to an application service. Args: - user_id(str): The user ID to see if it is an application service. + user_id: The user ID to see if it is an application service. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.sender == user_id: return service return None - def get_app_service_by_token(self, token): + def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: """Get the application service with the given appservice token. Args: - token (str): The application service token. + token: The application service token. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.token == token: return service return None - def get_app_service_by_id(self, as_id): + def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: """Get the application service with the given appservice ID. Args: - as_id (str): The application service ID. + as_id: The application service ID. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.id == as_id: @@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state(self, state): + async def get_appservices_by_state( + self, state: ApplicationServiceState + ) -> List[ApplicationService]: """Get a list of application services based on their state. Args: - state(ApplicationServiceState): The state to filter on. + state: The state to filter on. Returns: A list of ApplicationServices, which may be empty. """ @@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - async def get_appservice_state(self, service): + async def get_appservice_state( + self, service: ApplicationService + ) -> Optional[ApplicationServiceState]: """Get the application service state. Args: - service(ApplicationService): The service whose state to set. + service: The service whose state to set. Returns: - An ApplicationServiceState. + An ApplicationServiceState or none. """ result = await self.db_pool.simple_select_one( "application_services_state", @@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - async def set_appservice_state(self, service, state) -> None: + async def set_appservice_state( + self, service: ApplicationService, state: ApplicationServiceState + ) -> None: """Set the application service state. Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. + service: The service whose state to set. + state: The connectivity state to apply. """ await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} @@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore( "create_appservice_txn", _create_appservice_txn ) - async def complete_appservice_txn(self, txn_id, service) -> None: + async def complete_appservice_txn( + self, txn_id: int, service: ApplicationService + ) -> None: """Completes an application service transaction. Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. + txn_id: The transaction ID being completed. + service: The application service which was sent this transaction. """ txn_id = int(txn_id) @@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore( # has probably missed some events), so whine loudly but still continue, # since it shouldn't fail completion of the transaction. last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: + if (txn_id + 1) != txn_id: logger.error( "appservice: Completing a transaction which has an ID > 1 from " "the last ID sent to this AS. We've either dropped events or " @@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - async def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. + async def get_oldest_unsent_txn( + self, service: ApplicationService + ) -> Optional[AppServiceTransaction]: + """Get the oldest transaction which has not been sent for this service. Args: - service(ApplicationService): The app service to get the oldest txn. + service: The app service to get the oldest txn. Returns: An AppServiceTransaction or None. """ @@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore( service=service, id=entry["txn_id"], events=events, ephemeral=[] ) - def _get_last_txn(self, txn, service_id): + def _get_last_txn(self, txn, service_id: Optional[str]) -> int: txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", (service_id,), @@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos) -> None: + async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) @@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - async def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice( + self, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): @@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore( ) async def set_type_stream_id_for_appservice( - self, service: ApplicationService, type: str, pos: int + self, service: ApplicationService, type: str, pos: Optional[int] ) -> None: if type not in ("read_receipt", "presence"): raise ValueError( -- cgit 1.5.1 From b6ca69e4f109c745f022885ecb8aa86255f84ecf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 28 Oct 2020 15:51:15 +0000 Subject: Remove frozendict_json_encoder and support frozendicts everywhere Not being able to serialise `frozendicts` is fragile, and it's annoying to have to think about which serialiser you want. There's no real downside to supporting frozendicts, so let's just have one json encoder. --- changelog.d/8678.bugfix | 1 + synapse/handlers/message.py | 5 ++--- synapse/http/server.py | 2 +- synapse/storage/databases/main/censor_events.py | 6 +++--- synapse/storage/databases/main/events.py | 10 ++++------ synapse/util/__init__.py | 24 +++++++++++++++++++++--- synapse/util/frozenutils.py | 22 ---------------------- 7 files changed, 32 insertions(+), 38 deletions(-) create mode 100644 changelog.d/8678.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/8678.bugfix b/changelog.d/8678.bugfix new file mode 100644 index 0000000000..0508d8f109 --- /dev/null +++ b/changelog.d/8678.bugfix @@ -0,0 +1 @@ +Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6855c60ea..fb0a04e9a7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester -from synapse.util import json_decoder +from synapse.util import json_decoder, json_encoder from synapse.util.async_helpers import Linearizer -from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -928,7 +927,7 @@ class EventCreationHandler: # Ensure that we can round trip before trying to persist in db try: - dump = frozendict_json_encoder.encode(event.content) + dump = json_encoder.encode(event.content) json_decoder.decode(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) diff --git a/synapse/http/server.py b/synapse/http/server.py index 00b98af3d4..29dd604f85 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -620,7 +620,7 @@ def respond_with_json( if pretty_print: encoder = iterencode_pretty_printed_json else: - if canonical_json or synapse.events.USE_FROZEN_DICTS: + if canonical_json: encoder = iterencode_canonical_json else: encoder = _encode_json_bytes diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 849bd5ba7a..3e26d5ba87 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util import json_encoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase and original_event.internal_metadata.is_redacted() ): # Redaction was allowed - pruned_json = frozendict_json_encoder.encode( + pruned_json = json_encoder.encode( prune_event_dict( original_event.room_version, original_event.get_dict() ) @@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase return # Prune the event's dict then convert it to JSON. - pruned_json = frozendict_json_encoder.encode( + pruned_json = json_encoder.encode( prune_event_dict(event.room_version, event.get_dict()) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 87808c1483..90fb1a1f00 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import StateMap, get_domain_from_id -from synapse.util.frozenutils import frozendict_json_encoder +from synapse.util import json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -769,9 +769,7 @@ class PersistEventsStore: logger.exception("") raise - metadata_json = frozendict_json_encoder.encode( - event.internal_metadata.get_dict() - ) + metadata_json = json_encoder.encode(event.internal_metadata.get_dict()) sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -826,10 +824,10 @@ class PersistEventsStore: { "event_id": event.event_id, "room_id": event.room_id, - "internal_metadata": frozendict_json_encoder.encode( + "internal_metadata": json_encoder.encode( event.internal_metadata.get_dict() ), - "json": frozendict_json_encoder.encode(event_dict(event)), + "json": json_encoder.encode(event_dict(event)), "format_version": event.format_version, } for event, _ in events_and_contexts diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index d55b93d763..517686f0a6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -18,6 +18,7 @@ import logging import re import attr +from frozendict import frozendict from twisted.internet import defer, task @@ -31,9 +32,26 @@ def _reject_invalid_json(val): raise ValueError("Invalid JSON value: '%s'" % val) -# Create a custom encoder to reduce the whitespace produced by JSON encoding and -# ensure that valid JSON is produced. -json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) +def _handle_frozendict(obj): + """Helper for json_encoder. Makes frozendicts serializable by returning + the underlying dict + """ + if type(obj) is frozendict: + # fishing the protected dict out of the object is a bit nasty, + # but we don't really want the overhead of copying the dict. + return obj._dict + raise TypeError( + "Object of type %s is not JSON serializable" % obj.__class__.__name__ + ) + + +# A custom JSON encoder which: +# * handles frozendicts +# * produces valid JSON (no NaNs etc) +# * reduces redundant whitespace +json_encoder = json.JSONEncoder( + allow_nan=False, separators=(",", ":"), default=_handle_frozendict +) # Create a custom decoder to reject Python extensions to JSON. json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index bf094c9386..5f7a6dd1d3 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - from frozendict import frozendict @@ -49,23 +47,3 @@ def unfreeze(o): pass return o - - -def _handle_frozendict(obj): - """Helper for EventEncoder. Makes frozendicts serializable by returning - the underlying dict - """ - if type(obj) is frozendict: - # fishing the protected dict out of the object is a bit nasty, - # but we don't really want the overhead of copying the dict. - return obj._dict - raise TypeError( - "Object of type %s is not JSON serializable" % obj.__class__.__name__ - ) - - -# A JSONEncoder which is capable of encoding frozendicts without barfing. -# Additionally reduce the whitespace produced by JSON encoding. -frozendict_json_encoder = json.JSONEncoder( - allow_nan=False, separators=(",", ":"), default=_handle_frozendict, -) -- cgit 1.5.1 From f21e24ffc22a5eb01f242f47fa30979321cf20fc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Oct 2020 15:58:44 +0000 Subject: Add ability for access tokens to belong to one user but grant access to another user. (#8616) We do it this way round so that only the "owner" can delete the access token (i.e. `/logout/all` by the "owner" also deletes that token, but `/logout/all` by the "target user" doesn't). A future PR will add an API for creating such a token. When the target user and authenticated entity are different the `Processed request` log line will be logged with a: `{@admin:server as @bob:server} ...`. I'm not convinced by that format (especially since it adds spaces in there, making it harder to use `cut -d ' '` to chop off the start of log lines). Suggestions welcome. --- changelog.d/8616.misc | 1 + synapse/api/auth.py | 113 +++++++++------------ synapse/appservice/__init__.py | 4 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/auth.py | 8 +- synapse/handlers/register.py | 7 +- synapse/http/site.py | 30 ++++-- synapse/replication/http/membership.py | 6 +- synapse/replication/http/send_event.py | 3 +- synapse/storage/databases/main/registration.py | 48 +++++++-- .../main/schema/delta/58/22puppet_token.sql | 17 ++++ synapse/types.py | 33 ++++-- tests/api/test_auth.py | 29 +++--- tests/api/test_ratelimiting.py | 4 +- tests/appservice/test_appservice.py | 1 + tests/handlers/test_device.py | 2 +- tests/handlers/test_message.py | 2 +- tests/push/test_email.py | 2 +- tests/push/test_http.py | 10 +- tests/replication/test_pusher_shard.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 1 + tests/storage/test_registration.py | 10 +- 22 files changed, 197 insertions(+), 138 deletions(-) create mode 100644 changelog.d/8616.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/22puppet_token.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/8616.misc b/changelog.d/8616.misc new file mode 100644 index 0000000000..385b14063e --- /dev/null +++ b/changelog.d/8616.misc @@ -0,0 +1 @@ +Change schema to support access tokens belonging to one user but granting access to another. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 526cb58c5f..bfcaf68b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -190,10 +191,6 @@ class Auth: user_id, app_service = await self._get_appservice_user_id(request) if user_id: - request.authenticated_entity = user_id - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( user_id=user_id, @@ -203,31 +200,38 @@ class Auth: device_id="dummy-device", # stubbed ) - return synapse.types.create_requester(user_id, app_service=app_service) + requester = synapse.types.create_requester( + user_id, app_service=app_service + ) + + request.requester = user_id + opentracing.set_tag("authenticated_entity", user_id) + opentracing.set_tag("user_id", user_id) + opentracing.set_tag("appservice_id", app_service.id) + + return requester user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) - user = user_info["user"] - token_id = user_info["token_id"] - is_guest = user_info["is_guest"] - shadow_banned = user_info["shadow_banned"] + token_id = user_info.token_id + is_guest = user_info.is_guest + shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: - user_id = user.to_string() - if await self.store.is_account_expired(user_id, self.clock.time_msec()): + if await self.store.is_account_expired( + user_info.user_id, self.clock.time_msec() + ): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) - # device_id may not be present if get_user_by_access_token has been - # stubbed out. - device_id = user_info.get("device_id") + device_id = user_info.device_id - if user and access_token and ip_addr: + if access_token and ip_addr: await self.store.insert_client_ip( - user_id=user.to_string(), + user_id=user_info.token_owner, access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -241,19 +245,23 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - request.authenticated_entity = user.to_string() - opentracing.set_tag("authenticated_entity", user.to_string()) - if device_id: - opentracing.set_tag("device_id", device_id) - - return synapse.types.create_requester( - user, + requester = synapse.types.create_requester( + user_info.user_id, token_id, is_guest, shadow_banned, device_id, app_service=app_service, + authenticated_entity=user_info.token_owner, ) + + request.requester = requester + opentracing.set_tag("authenticated_entity", user_info.token_owner) + opentracing.set_tag("user_id", user_info.user_id) + if device_id: + opentracing.set_tag("device_id", device_id) + + return requester except KeyError: raise MissingClientTokenError() @@ -284,7 +292,7 @@ class Auth: async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ) -> dict: + ) -> TokenLookupResult: """ Validate access token and get user_id from it Args: @@ -293,13 +301,7 @@ class Auth: allow this allow_expired: If False, raises an InvalidClientTokenError if the token is expired - Returns: - dict that includes: - `user` (UserID) - `is_guest` (bool) - `shadow_banned` (bool) - `token_id` (int|None): access token id. May be None if guest - `device_id` (str|None): device corresponding to access token + Raises: InvalidClientTokenError if a user by that token exists, but the token is expired @@ -309,9 +311,9 @@ class Auth: if rights == "access": # first look in the database - r = await self._look_up_user_by_access_token(token) + r = await self.store.get_user_by_access_token(token) if r: - valid_until_ms = r["valid_until_ms"] + valid_until_ms = r.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -328,7 +330,6 @@ class Auth: # otherwise it needs to be a valid macaroon try: user_id, guest = self._parse_and_validate_macaroon(token, rights) - user = UserID.from_string(user_id) if rights == "access": if not guest: @@ -354,23 +355,17 @@ class Auth: raise InvalidClientTokenError( "Guest access token used for regular user" ) - ret = { - "user": user, - "is_guest": True, - "shadow_banned": False, - "token_id": None, + + ret = TokenLookupResult( + user_id=user_id, + is_guest=True, # all guests get the same device id - "device_id": GUEST_DEVICE_ID, - } + device_id=GUEST_DEVICE_ID, + ) elif rights == "delete_pusher": # We don't store these tokens in the database - ret = { - "user": user, - "is_guest": False, - "shadow_banned": False, - "token_id": None, - "device_id": None, - } + + ret = TokenLookupResult(user_id=user_id, is_guest=False) else: raise RuntimeError("Unknown rights setting %s", rights) return ret @@ -479,31 +474,15 @@ class Auth: now = self.hs.get_clock().time_msec() return now < expiry - async def _look_up_user_by_access_token(self, token): - ret = await self.store.get_user_by_access_token(token) - if not ret: - return None - - # we use ret.get() below because *lots* of unit tests stub out - # get_user_by_access_token in a way where it only returns a couple of - # the fields. - user_info = { - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - "is_guest": False, - "shadow_banned": ret.get("shadow_banned"), - "device_id": ret.get("device_id"), - "valid_until_ms": ret.get("valid_until_ms"), - } - return user_info - def get_appservice_by_req(self, request): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() - request.authenticated_entity = service.sender + request.requester = synapse.types.create_requester( + service.sender, app_service=service + ) return service async def is_server_admin(self, user: UserID) -> bool: diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 3862d9c08f..66d008d2f4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -52,11 +52,11 @@ class ApplicationService: self, token, hostname, + id, + sender, url=None, namespaces=None, hs_token=None, - sender=None, - id=None, protocols=None, rate_limited=True, ip_range_whitelist=None, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 3a6b95631e..a0933fae88 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -154,7 +154,7 @@ class Authenticator: ) logger.debug("Request from %s", origin) - request.authenticated_entity = origin + request.requester = origin # If we get a valid signed request from the other side, its probably # alive diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 276594f3d9..ff103cbb92 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -991,17 +991,17 @@ class AuthHandler(BaseHandler): # This might return an awaitable, if it does block the log out # until it completes. result = provider.on_logged_out( - user_id=str(user_info["user"]), - device_id=user_info["device_id"], + user_id=user_info.user_id, + device_id=user_info.device_id, access_token=access_token, ) if inspect.isawaitable(result): await result # delete pushers associated with this access token - if user_info["token_id"] is not None: + if user_info.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"],) + user_info.user_id, (user_info.token_id,) ) async def delete_access_tokens_for_user( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a6f1d21674..ed1ff62599 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = await self.auth.get_user_by_access_token(guest_access_token) - if not user_data["is_guest"] or user_data["user"].localpart != localpart: + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id await self.pusher_pool.add_pusher( user_id=user_id, diff --git a/synapse/http/site.py b/synapse/http/site.py index ddb1770b09..5f0581dc3f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional +from typing import Optional, Union from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -54,9 +55,12 @@ class SynapseRequest(Request): Request.__init__(self, channel, *args, **kw) self.site = channel.site self._channel = channel # this is used by the tests - self.authenticated_entity = None self.start_time = 0.0 + # The requester, if authenticated. For federation requests this is the + # server name, for client requests this is the Requester object. + self.requester = None # type: Optional[Union[Requester, str]] + # we can't yet create the logcontext, as we don't know the method. self.logcontext = None # type: Optional[LoggingContext] @@ -271,11 +275,23 @@ class SynapseRequest(Request): # to the client (nb may be negative) response_send_time = self.finish_time - self._processing_finished_time - # need to decode as it could be raw utf-8 bytes - # from a IDN servname in an auth header - authenticated_entity = self.authenticated_entity - if authenticated_entity is not None and isinstance(authenticated_entity, bytes): - authenticated_entity = authenticated_entity.decode("utf-8", "replace") + # Convert the requester into a string that we can log + authenticated_entity = None + if isinstance(self.requester, str): + authenticated_entity = self.requester + elif isinstance(self.requester, Requester): + authenticated_entity = self.requester.authenticated_entity + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we log both. + if self.requester.user.to_string() != authenticated_entity: + authenticated_entity = "{},{}".format( + authenticated_entity, self.requester.user.to_string(), + ) + elif self.requester is not None: + # This shouldn't happen, but we log it so we don't lose information + # and can see that we're doing something wrong. + authenticated_entity = repr(self.requester) # type: ignore[unreachable] # ...or could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index e7cc74a5d2..f0c37eaf5e 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info("remote_join: %s into room: %s", user_id, room_id) @@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester # hopefully we're now on the master, so this won't recurse! event_id, stream_id = await self.member_handler.remote_reject_invite( diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index fc129dbaa7..8fa104c8d3 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info( "Got event to send with ID: %s into room: %s", event.event_id, event.room_id diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e7b17a7385..e5d07ce72a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -18,6 +18,8 @@ import logging import re from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import attr + from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) +@attr.s(frozen=True, slots=True) +class TokenLookupResult: + """Result of looking up an access token. + + Attributes: + user_id: The user that this token authenticates as + is_guest + shadow_banned + token_id: The ID of the access token looked up + device_id: The device associated with the token, if any. + valid_until_ms: The timestamp the token expires, if any. + token_owner: The "owner" of the token. This is either the same as the + user, or a server admin who is logged in as the user. + """ + + user_id = attr.ib(type=str) + is_guest = attr.ib(type=bool, default=False) + shadow_banned = attr.ib(type=bool, default=False) + token_id = attr.ib(type=Optional[int], default=None) + device_id = attr.ib(type=Optional[str], default=None) + valid_until_ms = attr.ib(type=Optional[int], default=None) + token_owner = attr.ib(type=str) + + # Make the token owner default to the user ID, which is the common case. + @token_owner.default + def _default_token_owner(self): + return self.user_id + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) @@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return is_trial @cached() - async def get_user_by_access_token(self, token: str) -> Optional[dict]: + async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: """Get a user from the given access token. Args: token: The access token of a user. Returns: - None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise a `TokenLookupResult` """ return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) - def _query_for_auth(self, txn, token): + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ - SELECT users.name, + SELECT users.name as user_id, users.is_guest, users.shadow_banned, access_tokens.id as token_id, access_tokens.device_id, - access_tokens.valid_until_ms + access_tokens.valid_until_ms, + access_tokens.user_id as token_owner FROM users - INNER JOIN access_tokens on users.name = access_tokens.user_id + INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) if rows: - return rows[0] + return TokenLookupResult(**rows[0]) return None diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql new file mode 100644 index 0000000000..00a9431a97 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Whether the access token is an admin token for controlling another user. +ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT; diff --git a/synapse/types.py b/synapse/types.py index 5bde67cc07..66bb5bac8d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -29,6 +29,7 @@ from typing import ( Tuple, Type, TypeVar, + Union, ) import attr @@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore # define a version of typing.Collection that works on python 3.5 @@ -74,6 +76,7 @@ class Requester( "shadow_banned", "device_id", "app_service", + "authenticated_entity", ], ) ): @@ -104,6 +107,7 @@ class Requester( "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, } @staticmethod @@ -129,16 +133,18 @@ class Requester( shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, + authenticated_entity=input["authenticated_entity"], ) def create_requester( - user_id, - access_token_id=None, - is_guest=False, - shadow_banned=False, - device_id=None, - app_service=None, + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: Optional[bool] = False, + shadow_banned: Optional[bool] = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, ): """ Create a new ``Requester`` object @@ -151,14 +157,27 @@ def create_requester( shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. Returns: Requester """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + return Requester( - user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cb6f29d670..0fd55f428a 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -29,6 +29,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest @@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} + user_info = TokenLookupResult( + user_id=self.test_user, token_id=5, device_id="device" + ) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): - user_info = {"name": self.test_user, "token_id": "ditto"} + user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( return_value=defer.succeed( - {"name": "@baldrick:matrix.org", "device_id": "device"} + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) ) @@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(macaroon.serialize()) ) - user = user_info["user"] - self.assertEqual(UserID.from_string(user_id), user) + self.assertEqual(user_id, user_info.user_id) # TODO: device_id should come from the macaroon, but currently comes # from the db. - self.assertEqual(user_info["device_id"], "device") + self.assertEqual(user_info.device_id, "device") @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): @@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(serialized) ) - user = user_info["user"] - is_guest = user_info["is_guest"] - self.assertEqual(UserID.from_string(user_id), user) - self.assertTrue(is_guest) + self.assertEqual(user_id, user_info.user_id) + self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks @@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase): if token != tok: return defer.succeed(None) return defer.succeed( - { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + ) ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 1e1f30d790..fe504d0869 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, + None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, + None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 236b608d58..0bffeb1150 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( id="unique_identifier", + sender="@as:test", url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 4512c51311..875aaec2c6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # make sure that our device ID has changed user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) - self.assertEqual(user_info["device_id"], retrieved_device_id) + self.assertEqual(user_info.device_id, retrieved_device_id) # make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9f6f21a6e2..2e0fea04af 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.info = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token,) ) - self.token_id = self.info["token_id"] + self.token_id = self.info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index d9993e6245..bcdcafa5a9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/push/test_http.py b/tests/push/test_http.py index b567868b02..8571924b29 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 2bdc6edbb1..67c27a089f 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_dict["token_id"] + token_id = user_dict.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 2fc3a60fc5..98c3887bbf 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", ) self.hs.get_datastore().services_cache.append(appservice) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 6b582771fe..c8c7a90e5d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store.get_user_by_access_token(self.tokens[1]) ) - self.assertDictContainsSubset( - {"name": self.user_id, "device_id": self.device_id}, result - ) - - self.assertTrue("token_id" in result) + self.assertEqual(result.user_id, self.user_id) + self.assertEqual(result.device_id, self.device_id) + self.assertIsNotNone(result.token_id) @defer.inlineCallbacks def test_user_delete_access_tokens(self): @@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user = yield defer.ensureDeferred( self.store.get_user_by_access_token(self.tokens[0]) ) - self.assertEqual(self.user_id, user["name"]) + self.assertEqual(self.user_id, user.user_id) # now delete the rest yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) -- cgit 1.5.1 From 450415154649b2355a8a38b6abf8de973fc6428a Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 30 Oct 2020 00:22:31 +0000 Subject: Fix optional parameter in stripped state storage method (#8688) Missed in #8671. --- changelog.d/8688.misc | 1 + synapse/storage/databases/main/events_worker.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8688.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8688.misc b/changelog.d/8688.misc new file mode 100644 index 0000000000..bef8dc425a --- /dev/null +++ b/changelog.d/8688.misc @@ -0,0 +1 @@ +Abstract some invite-related code in preparation for landing knocking. \ No newline at end of file diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5ae263827d..4732685f6e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -530,7 +530,7 @@ class EventsWorkerStore(SQLBaseStore): self, context: EventContext, state_types_to_include: List[EventTypes], - membership_user_id: Optional[str], + membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ Retrieve the stripped state from a room, given an event context to retrieve state -- cgit 1.5.1 From 46f4be94b410776ef3f922af2f437eb17631d2fa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 30 Oct 2020 10:55:24 +0000 Subject: Fix race for concurrent downloads of remote media. (#8682) Fixes #6755 --- changelog.d/8682.bugfix | 1 + synapse/rest/media/v1/media_repository.py | 165 +++++++----- synapse/rest/media/v1/media_storage.py | 30 ++- synapse/storage/databases/main/media_repository.py | 27 ++ tests/replication/test_multi_media_repo.py | 277 +++++++++++++++++++++ tests/server.py | 2 +- 6 files changed, 431 insertions(+), 71 deletions(-) create mode 100644 changelog.d/8682.bugfix create mode 100644 tests/replication/test_multi_media_repo.py (limited to 'synapse/storage/databases') diff --git a/changelog.d/8682.bugfix b/changelog.d/8682.bugfix new file mode 100644 index 0000000000..e61276aa05 --- /dev/null +++ b/changelog.d/8682.bugfix @@ -0,0 +1 @@ +Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 5cce7237a0..9cac74ebd8 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -305,15 +305,12 @@ class MediaRepository: # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new # one. - if media_info: - file_id = media_info["filesystem_id"] - else: - file_id = random_string(24) - - file_info = FileInfo(server_name, file_id) # If we have an entry in the DB, try and look for it if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + if media_info["quarantined_by"]: logger.info("Media is quarantined") raise NotFoundError() @@ -324,14 +321,34 @@ class MediaRepository: # Failed to find the file anywhere, lets download it. - media_info = await self._download_remote_file(server_name, media_id, file_id) + try: + media_info = await self._download_remote_file(server_name, media_id,) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file( - self, server_name: str, media_id: str, file_id: str - ) -> dict: + async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -346,6 +363,8 @@ class MediaRepository: The media info of the file. """ + file_id = random_string(24) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): @@ -401,22 +420,32 @@ class MediaRepository: await finish() - media_type = headers[b"Content-Type"][0].decode("ascii") - upload_name = get_filename_from_headers(headers) - time_now_ms = self.clock.time_msec() + media_type = headers[b"Content-Type"][0].decode("ascii") + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) logger.info("Stored remote media in file %r", fname) - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - media_info = { "media_type": media_type, "media_length": length, @@ -425,8 +454,6 @@ class MediaRepository: "filesystem_id": file_id, } - await self._generate_thumbnails(server_name, media_id, file_id, media_type) - return media_info def _get_thumbnail_requirements(self, media_type): @@ -692,42 +719,60 @@ class MediaRepository: if not t_byte_source: continue - try: - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, - url_cache=url_cache, - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - t_len = os.path.getsize(output_path) + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) - # Write to database - if server_name: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = await self.store.get_remote_media_thumbnail( + server_name, media_id, t_width, t_height, t_type, + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a9586fb0b7..268e0c8f50 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -52,6 +52,7 @@ class MediaStorage: storage_providers: Sequence["StorageProviderWrapper"], ): self.hs = hs + self.reactor = hs.get_reactor() self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers @@ -70,13 +71,16 @@ class MediaStorage: with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - await defer_to_thread( - self.hs.get_reactor(), _write_file_synchronously, source, f - ) + await self.write_to_file(source, f) await finish_cb() return fname + async def write_to_file(self, source: IO, output: IO): + """Asynchronously write the `source` to `output`. + """ + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + @contextlib.contextmanager def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as @@ -112,14 +116,20 @@ class MediaStorage: finished_called = [False] - async def finish(): - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - try: with open(fname, "wb") as f: + + async def finish(): + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + yield f, fname, finish except Exception: try: @@ -210,7 +220,7 @@ class MediaStorage: if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor() + open(local_path, "wb"), self.reactor ) await res.write_to_consumer(consumer) await consumer.wait() diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index daf57675d8..4b2f224718 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) + async def get_remote_media_thumbnail( + self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + ) -> Optional[Dict[str, Any]]: + """Fetch the thumbnail info of given width, height and type. + """ + + return await self.db_pool.simple_select_one( + table="remote_media_cache_thumbnails", + keyvalues={ + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": t_width, + "thumbnail_height": t_height, + "thumbnail_type": t_type, + }, + retcols=( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + allow_none=True, + desc="get_remote_media_thumbnail", + ) + async def store_remote_media_thumbnail( self, origin, diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py new file mode 100644 index 0000000000..77c261dbf7 --- /dev/null +++ b/tests/replication/test_multi_media_repo.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from binascii import unhexlify +from typing import Tuple + +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel +from twisted.web.server import Request + +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.server import HomeServer + +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, FakeTransport + +logger = logging.getLogger(__name__) + +test_server_connection_factory = None + + +class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "127.0.0.2" + + def default_config(self): + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. + """ + + request, channel = self.make_request( + "GET", + "/{}/{}".format(target, media_id), + shorthand=False, + access_token=self.access_token, + ) + request.render(hs.get_media_repository_resource().children[b"download"]) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_tls_protocol = _build_test_server(get_connection_factory()) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self): + """Test basic fetching of remote media from a single worker. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request.write(b"Hello!") + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"Hello!") + + def test_download_simple_file_race(self): + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request1.write(b"Hello!") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"Hello!") + + # Now respond to the second with the same content. + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request2.write(b"Hello!") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"Hello!") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self): + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + png_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request1.write(png_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], png_data) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request2.write(png_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], png_data) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[b"DNS:example.com"] + ) + return test_server_connection_factory + + +def _build_test_server(connection_creator): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=server_factory + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/server.py b/tests/server.py index b97003fa5a..3dd2cfc072 100644 --- a/tests/server.py +++ b/tests/server.py @@ -46,7 +46,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() - result = attr.ib(default=attr.Factory(dict)) + result = attr.ib(type=dict, default=attr.Factory(dict)) _producer = None @property -- cgit 1.5.1