diff options
Diffstat (limited to 'synapse')
27 files changed, 729 insertions, 220 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 5da6c924fc..da52463531 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.39.0" +__version__ = "1.40.0rc2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 8d5f38b5d9..d119427ad8 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -151,6 +151,15 @@ class CacheConfig(Config): # entries are never evicted based on time. # #expiry_time: 30m + + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m """ def read_config(self, config, **kwargs): @@ -212,6 +221,10 @@ class CacheConfig(Config): else: self.expiry_time_msec = None + self.sync_response_cache_duration = self.parse_duration( + cache_config.get("sync_response_cache_duration", 0) + ) + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index dcd3ed1dac..ad4e6e61c3 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -71,7 +71,7 @@ handlers: # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR # logs will still be flushed immediately. buffer: - class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler + class: logging.handlers.MemoryHandler target: file # The capacity is the number of log lines that are buffered before # being written to disk. Increasing this will lead to better @@ -79,9 +79,6 @@ handlers: # be written to disk. 
capacity: 10 flushLevel: 30 # Flush for WARNING logs as well - # The period of time, in seconds, between forced flushes. - # Messages will not be delayed for longer than this time. - period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 0dfb3a227a..7481f3bf5f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from collections import namedtuple from typing import Dict, List +from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements @@ -22,6 +24,8 @@ from synapse.util.module_loader import load_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + DEFAULT_THUMBNAIL_SIZES = [ {"width": 32, "height": 32, "method": "crop"}, {"width": 96, "height": 96, "method": "crop"}, @@ -36,6 +40,9 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ +HTTP_PROXY_SET_WARNING = """\ +The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -180,12 +187,17 @@ class ContentRepositoryConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) + proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: - raise ConfigError( - "For security, you must specify an explicit target IP address " - "blacklist in url_preview_ip_range_blacklist for url previewing " - "to work" - ) + if "http" not in proxy_env or "https" not in proxy_env: + raise ConfigError( + "For 
security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + else: + if "http" in proxy_env or "https" in proxy_env: + logger.warning("".join(HTTP_PROXY_SET_WARNING)) # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. @@ -292,6 +304,8 @@ class ContentRepositoryConfig(Config): # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # + # Note: The value is ignored when an HTTP proxy is in use + # #url_preview_ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/config/server.py b/synapse/config/server.py index b9e0c0b300..187b4301a0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -960,6 +960,8 @@ class ServerConfig(Config): # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # + # Note: The value is ignored when an HTTP proxy is in use + # #ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b7a10da15a..007d1a27dc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1290,7 +1290,7 @@ class FederationClient(FederationBase): ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: """Represents a single event in the result of a successful get_space_summary call. @@ -1299,12 +1299,13 @@ class FederationSpaceSummaryEventResult: object attributes. 
""" - event_type = attr.ib(type=str) - state_key = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + event_type: str + room_id: str + state_key: str + via: Sequence[str] # the raw data, including the above keys - data = attr.ib(type=JsonDict) + data: JsonDict @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": @@ -1321,6 +1322,10 @@ class FederationSpaceSummaryEventResult: if not isinstance(event_type, str): raise ValueError("Invalid event: 'event_type' must be a str") + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") + state_key = d.get("state_key") if not isinstance(state_key, str): raise ValueError("Invalid event: 'state_key' must be a str") @@ -1335,15 +1340,15 @@ class FederationSpaceSummaryEventResult: if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") - return cls(event_type, state_key, via, d) + return cls(event_type, room_id, state_key, via, d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms = attr.ib(type=Sequence[JsonDict]) - events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + rooms: Sequence[JsonDict] + events: Sequence[FederationSpaceSummaryEventResult] @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 21a17cd2e8..4ab4046650 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -392,9 +392,6 @@ class ApplicationServicesHandler: protocols[p].append(info) def _merge_instances(infos: List[JsonDict]) -> JsonDict: - if not infos: - return {} - # Merge the 'instances' lists of multiple results, but just take # the other fields from the first as they ought to be 
identical # copy the result so as not to corrupt the cached one @@ -406,7 +403,9 @@ class ApplicationServicesHandler: return combined - return {p: _merge_instances(protocols[p]) for p in protocols.keys()} + return { + p: _merge_instances(protocols[p]) for p in protocols.keys() if protocols[p] + } async def _get_services_for_event( self, event: EventBase diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8197b60b76..8b602e3813 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -42,6 +42,7 @@ from twisted.internet import defer from synapse import event_auth from synapse.api.constants import ( + EventContentFields, EventTypes, Membership, RejectedReason, @@ -262,7 +263,12 @@ class FederationHandler(BaseHandler): state = None - # Get missing pdus if necessary. + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) @@ -432,6 +438,13 @@ class FederationHandler(BaseHandler): affected=event_id, ) + # A second round of checks for all events. Check that the event passes auth + # based on `auth_events`, this allows us to assert that the event would + # have been allowed at some point. If an event passes this check its OK + # for it to be used as part of a returned `/state` request, as either + # a) we received the event as part of the original join and so trust it, or + # b) we'll do a state resolution with existing state before it becomes + # part of the "current state", which adds more protection. 
await self._process_received_pdu(origin, pdu, state=state) async def _get_missing_events_for_pdu( @@ -889,6 +902,79 @@ class FederationHandler(BaseHandler): "resync_device_due_to_pdu", self._resync_device, event.sender ) + await self._handle_marker_event(origin, event) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event. Will be called to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it. + room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion 
extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + async def _resync_device(self, sender: str) -> None: """We have detected that the device list for the given user may be out of sync, so we try and resync them. @@ -1057,9 +1143,19 @@ class FederationHandler(BaseHandler): async def _maybe_backfill_inner( self, room_id: str, current_depth: int, limit: int ) -> bool: - extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) + oldest_events_with_depth = ( + await self.store.get_oldest_event_ids_with_depth_in_room(room_id) + ) + insertion_events_to_be_backfilled = ( + await self.store.get_insertion_event_backwards_extremities_in_room(room_id) + ) + logger.debug( + "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", + oldest_events_with_depth, + insertion_events_to_be_backfilled, + ) - if not extremities: + if not oldest_events_with_depth and not insertion_events_to_be_backfilled: logger.debug("Not backfilling as no extremeties found.") return False @@ -1089,10 +1185,12 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. 
- forward_events = await self.store.get_successor_events(list(extremities)) + forward_event_ids = await self.store.get_successor_events( + list(oldest_events_with_depth) + ) extremities_events = await self.store.get_events( - forward_events, + forward_event_ids, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, ) @@ -1106,10 +1204,19 @@ class FederationHandler(BaseHandler): redact=False, check_history_visibility_only=True, ) + logger.debug( + "_maybe_backfill_inner: filtered_extremities %s", filtered_extremities + ) - if not filtered_extremities: + if not filtered_extremities and not insertion_events_to_be_backfilled: return False + extremities = { + **oldest_events_with_depth, + # TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks + **insertion_events_to_be_backfilled, + } + # Check if we reached a point where we should start backfilling. sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0961dec5ab..8ffeabacf9 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -824,6 +824,7 @@ class IdentityHandler(BaseHandler): room_avatar_url: str, room_join_rules: str, room_name: str, + room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, id_access_token: Optional[str] = None, @@ -843,6 +844,7 @@ class IdentityHandler(BaseHandler): notifications. room_join_rules: The join rules of the email (e.g. "public"). room_name: The m.room.name of the room. + room_type: The type of the room from its m.room.create event (e.g "m.space"). inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. 
@@ -869,6 +871,10 @@ class IdentityHandler(BaseHandler): "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + + if room_type is not None: + invite_config["org.matrix.msc3288.room_type"] = room_type + # If a custom web client location is available, include it in the request. if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b9085bbccb..5fd4525700 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -70,7 +70,8 @@ class ReceiptsHandler(BaseHandler): ) if not is_in_room: logger.info( - "Ignoring receipt from %s as we're not in the room", + "Ignoring receipt for room %r from server %s as we're not in the room", + room_id, origin, ) continue diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65ad3efa6a..ba13196218 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,12 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -1237,6 +1242,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if room_name_event: room_name = room_name_event.content.get("name", "") + room_type = None + room_create_event = room_state.get((EventTypes.Create, "")) + if room_create_event: + room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE) + room_join_rules = "" join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: @@ -1263,6 +1273,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_avatar_url=room_avatar_url, room_join_rules=room_join_rules, 
room_name=room_name, + room_type=room_type, inviter_display_name=inviter_display_name, inviter_avatar_url=inviter_avatar_url, id_access_token=id_access_token, diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 5f7d4602bd..3eb232c83e 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -16,7 +16,17 @@ import itertools import logging import re from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) import attr @@ -116,20 +126,22 @@ class SpaceSummaryHandler: max_children = max_rooms_per_space if processed_rooms else None if is_in_room: - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( requester, None, room_id, suggested_only, max_children ) + events: Collection[JsonDict] = [] + if room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children + logger.debug( "Query of local room %s returned events %s", room_id, ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) - - if room: - rooms_result.append(room) else: - fed_rooms, fed_events = await self._summarize_remote_room( + fed_rooms = await self._summarize_remote_room( queue_entry, suggested_only, max_children, @@ -141,12 +153,10 @@ class SpaceSummaryHandler: # user is not permitted see. # # Filter the returned results to only what is accessible to the user. - room_ids = set() events = [] - for room in fed_rooms: - fed_room_id = room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id # The room should only be included in the summary if: # a. the user is in the room; @@ -189,21 +199,17 @@ class SpaceSummaryHandler: # The user can see the room, include it! 
if include_room: rooms_result.append(room) - room_ids.add(fed_room_id) + events.extend(room_entry.children) # All rooms returned don't need visiting again (even if the user # didn't have access to them). processed_rooms.add(fed_room_id) - for event in fed_events: - if event.get("room_id") in room_ids: - events.append(event) - logger.debug( "Query of %s returned rooms %s, events %s", room_id, - [room.get("room_id") for room in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) # the room we queried may or may not have been returned, but don't process @@ -283,20 +289,20 @@ class SpaceSummaryHandler: # already done this room continue - logger.debug("Processing room %s", room_id) - - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( None, origin, room_id, suggested_only, max_rooms_per_space ) processed_rooms.add(room_id) - if room: - rooms_result.append(room) - events_result.extend(events) + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children) - # add any children to the queue - room_queue.extend(edge_event["state_key"] for edge_event in events) + # add any children to the queue + room_queue.extend( + edge_event["state_key"] for edge_event in room_entry.children + ) return {"rooms": rooms_result, "events": events_result} @@ -307,7 +313,7 @@ class SpaceSummaryHandler: room_id: str, suggested_only: bool, max_children: Optional[int], - ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: + ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -326,21 +332,16 @@ class SpaceSummaryHandler: to a server-set limit. Returns: - A tuple of: - The room information, if the room should be returned to the - user. None, otherwise. - - An iterable of the sorted children events. 
This may be limited - to a maximum size or may include all children. + A room entry if the room should be returned. None, otherwise. """ if not await self._is_room_accessible(room_id, requester, origin): - return None, () + return None room_entry = await self._build_room_entry(room_id) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: - return room_entry, () + return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) @@ -363,7 +364,7 @@ class SpaceSummaryHandler: ) ) - return room_entry, events_result + return _RoomEntry(room_id, room_entry, events_result) async def _summarize_remote_room( self, @@ -371,7 +372,7 @@ class SpaceSummaryHandler: suggested_only: bool, max_children: Optional[int], exclude_rooms: Iterable[str], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + ) -> Iterable["_RoomEntry"]: """ Request room entries and a list of event entries for a given room by querying a remote server. @@ -386,11 +387,7 @@ class SpaceSummaryHandler: Rooms IDs which do not need to be summarized. Returns: - A tuple of: - An iterable of rooms. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + An iterable of room entries. """ room_id = room.room_id logger.info("Requesting summary for %s via %s", room_id, room.via) @@ -414,11 +411,30 @@ class SpaceSummaryHandler: e, exc_info=logger.isEnabledFor(logging.DEBUG), ) - return (), () + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. 
+ results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue - return res.rooms, tuple( - ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild - ) + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results async def _is_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] @@ -606,10 +622,21 @@ class SpaceSummaryHandler: return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: - room_id = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + room_id: str + via: Sequence[str] + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children: Collection[JsonDict] = () def _has_valid_via(e: EventBase) -> bool: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f30bfcc93c..590642f510 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,14 +269,22 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( - hs.get_clock(), "sync" - ) self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() self.state_store = self.storage.state + # TODO: flush cache entries on subsequent sync request. 
+ # Once we get the next /sync request (ie, one with the same access token + # that sets 'since' to 'next_batch'), we know that device won't need a + # cached result any more, and we could flush the entry from the cache to save + # memory. + self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( + hs.get_clock(), + "sync", + timeout_ms=hs.config.caches.sync_response_cache_duration, + ) + # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache: ExpiringCache[ Tuple[str, Optional[str]], LruCache[str, str] diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0cb651a400..a97c448595 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -335,7 +335,8 @@ class TypingWriterHandler(FollowerTypingHandler): ) if not is_in_room: logger.info( - "Ignoring typing update from %s as we're not in the room", + "Ignoring typing update for room %r from server %s as we're not in the room", + room_id, origin, ) return diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index a6c212f300..af5fc407a8 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -45,6 +45,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): self._flushing_thread: Thread = Thread( name="PeriodicallyFlushingMemoryHandler flushing thread", target=self._flush_periodically, + daemon=True, ) self._flushing_thread.start() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 473812b8e2..1cc13fc97b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types 
import JsonDict, Requester, UserID, UserInfo, create_requester from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -174,6 +174,16 @@ class ModuleApi: """The application name configured in the homeserver's configuration.""" return self._hs.config.email.email_app_name + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get user info by user_id + + Args: + user_id: Fully qualified user id. + Returns: + UserInfo object if a user was found, otherwise None + """ + return await self._store.get_userinfo_by_id(user_id) + async def get_user_by_req( self, req: SynapseRequest, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 4b98979b47..d9ab836cd8 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -43,7 +43,7 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - body = parse_json_object_from_request(request) + body = parse_json_object_from_request(request, allow_empty_body=True) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c8015a3848..95d2caff62 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -941,13 +941,13 @@ class DatabasePool: `lock` should generally be set to True (the default), but can be set to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. + 1. there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + 2. 
we somehow know that we are the only thread which will be updating + this table. + As an additional note, this parameter only matters for old SQLite versions + because we will use native upserts otherwise. Args: table: The table to upsert into diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1edc96042b..1f0a39eac4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """ @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + def _claim_e2e_one_time_key_simple( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, device_id, algorithm)) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + + self.db_pool.simple_delete_one_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, ) - fallback_sql = ( - "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" 
- " LIMIT 1" + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - result = {} - delete = [] - used_fallbacks = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is not None: - key_id, key_json = otk_row - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - else: - # no one-time key available, so see if there's a fallback - # key - txn.execute(fallback_sql, (user_id, device_id, algorithm)) - fallback_row = txn.fetchone() - if fallback_row is not None: - key_id, key_json, used = fallback_row - device_result[algorithm + ":" + key_id] = key_json - if not used: - used_fallbacks.append( - (user_id, device_id, algorithm, key_id) - ) - - # drop any one-time keys that were claimed - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" + + return f"{algorithm}:{key_id}", key_json + + @trace + def _claim_e2e_one_time_key_returning( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + # We can use RETURNING to do the fetch and DELETE in once step. + sql = """ + DELETE FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + AND key_id IN ( + SELECT key_id FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? 
+ LIMIT 1 + ) + RETURNING key_id, key_json + """ + + txn.execute( + sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." - } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + return f"{algorithm}:{key_id}", key_json + + results = {} + for user_id, device_id, algorithm in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False + + row = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + db_autocommit=db_autocommit, + ) + if row: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} ) - # mark fallback keys as used - for user_id, device_id, algorithm, key_id in used_fallbacks: - self.db_pool.simple_update_txn( - txn, - "e2e_fallback_keys_json", - { + device_results[row[0]] = row[1] + continue + + # No one-time key available, so see if there's a fallback + # key + row = await self.db_pool.simple_select_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + retcols=("key_id", "key_json", "used"), + desc="_get_fallback_key", + allow_none=True, + ) + if row is None: + continue + + key_id = row["key_id"] + key_json = row["key_json"] + used = row["used"] + + # Mark fallback key as used if not 
already. + if not used: + await self.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, }, - {"used": True}, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", ) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) ) - return result + device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) + device_results[f"{algorithm}:{key_id}"] = key_json - return await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + return results class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 44018c1c31..bddf5ef192 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -671,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - async def get_oldest_events_with_depth_in_room(self, room_id): + async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]: + """Gets the oldest events (backwards extremities) in the room along with the + approximate depth. + + We use this function so that we can compare and see if someone's current + depth at their current scrollback is within pagination range of the + event extremities. If the current depth is close to the depth of given + oldest event, we can trigger a backfill. 
+ + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + # Assemble a dictionary with event_id -> depth for the oldest events + # we know of in the room. Backwards extremities are the oldest + # events we know of in the room but we only know of them because + # some other event referenced them by prev_event and aren't persisted + # in our database yet (meaning we don't know their depth + # specifically). So we need to look for the approximate depth from + # the events connected to the current backwards extremities. + sql = """ + SELECT b.event_id, MAX(e.depth) FROM events as e + /** + * Get the edge connections from the event_edges table + * so we can see whether this event's prev_events points + * to a backward extremity in the next join. + */ + INNER JOIN event_edges as g + ON g.event_id = e.event_id + /** + * We find the "oldest" events in the room by looking for + * events connected to backwards extremeties (oldest events + * in the room that we know of so far). + */ + INNER JOIN event_backward_extremities as b + ON g.prev_event_id = b.event_id + WHERE b.room_id = ? AND g.is_state is ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id, False)) + + return dict(txn) + return await self.db_pool.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, + "get_oldest_event_ids_with_depth_in_room", + get_oldest_event_ids_with_depth_in_room_txn, room_id, ) - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" 
- " GROUP BY b.event_id" - ) + async def get_insertion_event_backwards_extremities_in_room( + self, room_id + ) -> Dict[str, int]: + """Get the insertion events we know about that we haven't backfilled yet. - txn.execute(sql, (room_id, False)) + We use this function so that we can compare and see if someone's current + depth at their current scrollback is within pagination range of the + insertion event. If the current depth is close to the depth of given + insertion event, we can trigger a backfill. - return dict(txn) + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + sql = """ + SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + /* We only want insertion events that are also marked as backwards extremities */ + INNER JOIN insertion_event_extremities as b USING (event_id) + /* Get the depth of the insertion event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE b.room_id = ? 
+ GROUP BY b.event_id + """ + + txn.execute(sql, (room_id,)) + + return dict(txn) + + return await self.db_pool.runInteraction( + "get_insertion_event_backwards_extremities_in_room", + get_insertion_event_backwards_extremities_in_room_txn, + room_id, + ) async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs @@ -1041,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if row[1] not in event_results: queue.put((-row[0], row[1])) - # Navigate up the DAG by prev_event txn.execute(query, (event_id, False, limit - len(event_results))) prev_event_id_results = txn.fetchall() logger.debug( @@ -1136,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: + await self.db_pool.simple_upsert( + table="insertion_event_extremities", + keyvalues={"event_id": event_id}, + values={ + "event_id": event_id, + "room_id": room_id, + }, + insertion_values={}, + desc="insert_insertion_extremity", + lock=False, + ) + async def insert_received_event_to_staging( self, origin: str, event: EventBase ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 86baf397fb..40b53274fb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1845,6 +1845,18 @@ class PersistEventsStore: }, ) + # When we receive an event with a `chunk_id` referencing the + # `next_chunk_id` of the insertion event, we can remove it from the + # `insertion_event_extremities` table. + sql = """ + DELETE FROM insertion_event_extremities WHERE event_id IN ( + SELECT event_id FROM insertion_events + WHERE next_chunk_id = ? 
+ ) + """ + + txn.execute(sql, (chunk_id,)) + def _handle_redaction(self, txn, redacted_event_id): """Handles receiving a redaction and checking whether we need to remove any redacted relations from the database. @@ -2101,15 +2113,17 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ + # From the events passed in, add all of the prev events as backwards extremities. + # Ignore any events that are already backwards extrems or outliers. query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" " )" " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" " )" ) @@ -2123,6 +2137,8 @@ class PersistEventsStore: ], ) + # Delete all these events that we've already fetched and now know that their + # prev events are the new backwards extremities. query = ( "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3c86adab56..375463e4e9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ from typing import ( overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. 
times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore): return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. 
Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + already_fetching: Dict[str, defer.Deferred] = {} + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching[event_id] = deferred.observe() + + missing_events_ids.difference_update(already_fetching) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. 
# - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) - event_entry_map.update(missing_events) + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + already_fetching.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + event_entry_map.update(result) + + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. 
If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. 
try: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7f..14670c2881 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. 
+ """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 68f1b40ea6..e8157ba3d4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. 
After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) + event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) missing_member_event_ids = [] for event_id in member_event_ids: ev_entry = event_map.get(event_id) - if ev_entry: + if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: users_in_room[ev_entry.event.state_key] = ProfileInfo( display_name=ev_entry.event.content.get("displayname", None), diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 36340a652a..fd4dd67d91 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 61 +SCHEMA_VERSION = 62 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql new file mode 100644 index 0000000000..b731ef284a --- /dev/null +++ b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql @@ -0,0 +1,24 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Add a table that keeps track of which "insertion" events need to be backfilled +CREATE TABLE IF NOT EXISTS insertion_event_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_extremities_event_id ON insertion_event_extremities(event_id); +CREATE INDEX IF NOT EXISTS insertion_event_extremities_room_id ON insertion_event_extremities(room_id); diff --git a/synapse/types.py b/synapse/types.py index 429bb013d2..80fa903c4b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info): # and return that one key for key_id, key_data in keys.items(): return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool |