diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 637a938bac..26eef6eb61 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,21 +15,31 @@
# limitations under the License.
import logging
import re
-from typing import List
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
-from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.appservice import (
+ ApplicationService,
+ ApplicationServiceState,
+ AppServiceTransaction,
+)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+ services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
@@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
- exclusive_user_regex = re.compile(exclusive_user_regex)
+ exclusive_user_pattern = re.compile(
+ exclusive_user_regex
+ ) # type: Optional[Pattern]
else:
# We handle this case specially otherwise the constructed regex
# will always match
- exclusive_user_regex = None
+ exclusive_user_pattern = None
- return exclusive_user_regex
+ return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
- def get_if_app_services_interested_in_user(self, user_id):
+ def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
if self.exclusive_user_regex:
@@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else:
return False
- def get_app_service_by_user_id(self, user_id):
+ def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
@@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service.
Args:
- user_id(str): The user ID to see if it is an application service.
+ user_id: The user ID to see if it is an application service.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.sender == user_id:
return service
return None
- def get_app_service_by_token(self, token):
+ def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token.
Args:
- token (str): The application service token.
+ token: The application service token.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.token == token:
return service
return None
- def get_app_service_by_id(self, as_id):
+ def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID.
Args:
- as_id (str): The application service ID.
+ as_id: The application service ID.
Returns:
- synapse.appservice.ApplicationService or None.
+ The application service or None.
"""
for service in self.services_cache:
if service.id == as_id:
@@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
- async def get_appservices_by_state(self, state):
+ async def get_appservices_by_state(
+ self, state: ApplicationServiceState
+ ) -> List[ApplicationService]:
"""Get a list of application services based on their state.
Args:
- state(ApplicationServiceState): The state to filter on.
+ state: The state to filter on.
Returns:
A list of ApplicationServices, which may be empty.
"""
@@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
- async def get_appservice_state(self, service):
+ async def get_appservice_state(
+ self, service: ApplicationService
+ ) -> Optional[ApplicationServiceState]:
"""Get the application service state.
Args:
- service(ApplicationService): The service whose state to set.
+ service: The service whose state to get.
Returns:
- An ApplicationServiceState.
+ An ApplicationServiceState or None.
"""
result = await self.db_pool.simple_select_one(
"application_services_state",
@@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
- async def set_appservice_state(self, service, state) -> None:
+ async def set_appservice_state(
+ self, service: ApplicationService, state: ApplicationServiceState
+ ) -> None:
"""Set the application service state.
Args:
- service(ApplicationService): The service whose state to set.
- state(ApplicationServiceState): The connectivity state to apply.
+ service: The service whose state to set.
+ state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
@@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
"create_appservice_txn", _create_appservice_txn
)
- async def complete_appservice_txn(self, txn_id, service) -> None:
+ async def complete_appservice_txn(
+ self, txn_id: int, service: ApplicationService
+ ) -> None:
"""Completes an application service transaction.
Args:
- txn_id(str): The transaction ID being completed.
- service(ApplicationService): The application service which was sent
- this transaction.
+ txn_id: The transaction ID being completed.
+ service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
@@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
@@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn
)
- async def get_oldest_unsent_txn(self, service):
- """Get the oldest transaction which has not been sent for this
- service.
+ async def get_oldest_unsent_txn(
+ self, service: ApplicationService
+ ) -> Optional[AppServiceTransaction]:
+ """Get the oldest transaction which has not been sent for this service.
Args:
- service(ApplicationService): The app service to get the oldest txn.
+ service: The app service to get the oldest txn.
Returns:
An AppServiceTransaction or None.
"""
@@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
)
- def _get_last_txn(self, txn, service_id):
+ def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
@@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
- async def set_appservice_last_pos(self, pos) -> None:
+ async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
- async def get_new_events_for_appservice(self, current_id, limit):
+ async def get_new_events_for_appservice(
+ self, current_id: int, limit: int
+ ) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
@@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
- self, service: ApplicationService, type: str, pos: int
+ self, service: ApplicationService, type: str, pos: Optional[int]
) -> None:
if type not in ("read_receipt", "presence"):
raise ValueError(
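The `_make_exclusive_regex` change above is the motivating case for the `Optional[Pattern]` return type: with no exclusive namespaces the helper must return `None`, because an empty alternation would match every user ID. A minimal standalone sketch of the same construction, using a cut-down stand-in for `ApplicationService` rather than the real class:

```python
import re
from typing import List, Optional, Pattern


class FakeAppService:
    """Stand-in for synapse.appservice.ApplicationService, reduced to the
    one piece of data the helper needs."""

    def __init__(self, exclusive_user_regexes: List[str]):
        self.exclusive_user_regexes = exclusive_user_regexes


def make_exclusive_regex(services: List[FakeAppService]) -> Optional[Pattern]:
    regexes = [r for s in services for r in s.exclusive_user_regexes]
    if not regexes:
        # No exclusive namespaces: return None rather than compiling an
        # empty alternation, which would match everything.
        return None
    # Wrap each regex in a group and join with "|", as the patched helper does.
    return re.compile("|".join("(" + r + ")" for r in regexes))


bridge = FakeAppService([r"@irc_.*:example\.com"])
pattern = make_exclusive_regex([bridge])
assert pattern is not None and pattern.match("@irc_alice:example.com")
assert make_exclusive_regex([]) is None
```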
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 849bd5ba7a..3e26d5ba87 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
- pruned_json = frozendict_json_encoder.encode(
+ pruned_json = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
- pruned_json = frozendict_json_encoder.encode(
+ pruned_json = json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)
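Both this file and `events.py` below swap `frozendict_json_encoder` for the shared `json_encoder`. A hedged sketch of why one encoder can cover both uses, assuming it carries a `default` hook that unwraps frozen mappings — the stand-in `FrozenDict` and the exact encoder options are illustrative, not Synapse's actual definitions:

```python
import json
from collections.abc import Mapping
from typing import Any


class FrozenDict(Mapping):
    """Minimal immutable-mapping stand-in for the type event contents use."""

    def __init__(self, data: dict) -> None:
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


def _handle_frozen(obj: Any) -> Any:
    # json only serializes real dicts; unwrap anything dict-like.
    if isinstance(obj, Mapping):
        return dict(obj)
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)


json_encoder = json.JSONEncoder(separators=(",", ":"), default=_handle_frozen)

# Nested frozen mappings round-trip too: `default` is applied recursively.
print(json_encoder.encode({"content": FrozenDict({"body": "hi"})}))
# -> {"content":{"body":"hi"}}
```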
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 87808c1483..90fb1a1f00 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
-from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -769,9 +769,7 @@ class PersistEventsStore:
logger.exception("")
raise
- metadata_json = frozendict_json_encoder.encode(
- event.internal_metadata.get_dict()
- )
+ metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -826,10 +824,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
- "internal_metadata": frozendict_json_encoder.encode(
+ "internal_metadata": json_encoder.encode(
event.internal_metadata.get_dict()
),
- "json": frozendict_json_encoder.encode(event_dict(event)),
+ "json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6e7f16f39c..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -31,6 +31,7 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import (
@@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
return event_map
+ async def get_stripped_room_state_from_event_context(
+ self,
+ context: EventContext,
+ state_types_to_include: List[EventTypes],
+ membership_user_id: Optional[str] = None,
+ ) -> List[JsonDict]:
+ """
+ Retrieve the stripped state from a room, given an event context to retrieve state
+ from as well as the state types to include. Optionally, include the membership
+ events from a specific user.
+
+ "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+ are included from each state event.
+
+ Args:
+ context: The event context to retrieve state of the room from.
+ state_types_to_include: The type of state events to include.
+ membership_user_id: An optional user ID to include the stripped membership state
+ events of. This is useful when generating the stripped state of a room for
+ invites. We want to send membership events of the inviter, so that the
+ invitee can display the inviter's profile information if the room lacks any.
+
+ Returns:
+ A list of dictionaries, each representing a stripped state event from the room.
+ """
+ current_state_ids = await context.get_current_state_ids()
+
+ # We know this event is not an outlier, so this must be
+ # non-None.
+ assert current_state_ids is not None
+
+ # The state to include
+ state_to_include_ids = [
+ e_id
+ for k, e_id in current_state_ids.items()
+ if k[0] in state_types_to_include
+ or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+ ]
+
+ state_to_include = await self.get_events(state_to_include_ids)
+
+ return [
+ {
+ "type": e.type,
+ "state_key": e.state_key,
+ "content": e.content,
+ "sender": e.sender,
+ }
+ for e in state_to_include.values()
+ ]
+
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
@@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
+ " LEFT JOIN room_memberships USING (event_id)"
+ " LEFT JOIN rejections USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" AND instance_name = ?"
" ORDER BY stream_ordering ASC"
@@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
+ " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
+ " LEFT JOIN room_memberships USING (event_id)"
+ " LEFT JOIN rejections USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" AND out.instance_name = ?"
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index daf57675d8..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -452,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails",
)
+ async def get_remote_media_thumbnail(
+ self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+ ) -> Optional[Dict[str, Any]]:
+ """Fetch the thumbnail info of given width, height and type.
+ """
+
+ return await self.db_pool.simple_select_one(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
+ "media_origin": origin,
+ "media_id": media_id,
+ "thumbnail_width": t_width,
+ "thumbnail_height": t_height,
+ "thumbnail_type": t_type,
+ },
+ retcols=(
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ "filesystem_id",
+ ),
+ allow_none=True,
+ desc="get_remote_media_thumbnail",
+ )
+
async def store_remote_media_thumbnail(
self,
origin,
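A hedged usage sketch for the new `get_remote_media_thumbnail` lookup; the calling function here is hypothetical, not a resource that exists in Synapse:

```python
from typing import Any, Dict, Optional, Tuple


async def serve_exact_thumbnail(
    store: Any, origin: str, media_id: str
) -> Optional[Tuple[str, int]]:
    # `store` is assumed to be a MediaRepositoryStore instance.
    info: Optional[Dict[str, Any]] = await store.get_remote_media_thumbnail(
        origin, media_id, t_width=320, t_height=240, t_type="image/jpeg"
    )
    if info is None:
        # No exact match cached; the caller would generate a thumbnail instead.
        return None
    # filesystem_id locates the file on disk; thumbnail_length is its byte size.
    return info["filesystem_id"], info["thumbnail_length"]
```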
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e7b17a7385..e5d07ce72a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -18,6 +18,8 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+ """Result of looking up an access token.
+
+ Attributes:
+ user_id: The user that this token authenticates as
+ is_guest: Whether the user is a guest user.
+ shadow_banned: Whether the user has been shadow-banned.
+ token_id: The ID of the access token looked up
+ device_id: The device associated with the token, if any.
+ valid_until_ms: The timestamp the token expires, if any.
+ token_owner: The "owner" of the token. This is either the same as the
+ user, or a server admin who is logged in as the user.
+ """
+
+ user_id = attr.ib(type=str)
+ is_guest = attr.ib(type=bool, default=False)
+ shadow_banned = attr.ib(type=bool, default=False)
+ token_id = attr.ib(type=Optional[int], default=None)
+ device_id = attr.ib(type=Optional[str], default=None)
+ valid_until_ms = attr.ib(type=Optional[int], default=None)
+ token_owner = attr.ib(type=str)
+
+ # Make the token owner default to the user ID, which is the common case.
+ @token_owner.default
+ def _default_token_owner(self):
+ return self.user_id
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial
@cached()
- async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+ async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
- None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None if the token did not match, otherwise a `TokenLookupResult`.
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
- def _query_for_auth(self, txn, token):
+ def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
- SELECT users.name,
+ SELECT users.name as user_id,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
- access_tokens.valid_until_ms
+ access_tokens.valid_until_ms,
+ access_tokens.user_id as token_owner
FROM users
- INNER JOIN access_tokens on users.name = access_tokens.user_id
+ INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
- return rows[0]
+ return TokenLookupResult(**rows[0])
return None
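The `@token_owner.default` pattern used by `TokenLookupResult` is worth isolating: attrs allows a field's default to be derived from another field, so `token_owner` falls back to `user_id` unless puppeting sets it explicitly. A cut-down sketch (`TokenResult` is illustrative, not the class from the patch):

```python
from typing import Optional

import attr


@attr.s(frozen=True, slots=True)
class TokenResult:
    user_id = attr.ib(type=str)
    token_id = attr.ib(type=Optional[int], default=None)
    token_owner = attr.ib(type=str)

    # The decorator runs in the class body, so token_owner gains a default
    # (computed from self) before @attr.s processes the fields.
    @token_owner.default
    def _default_token_owner(self) -> str:
        return self.user_id


assert TokenResult("@alice:example.com").token_owner == "@alice:example.com"
assert (
    TokenResult("@alice:example.com", token_owner="@admin:example.com").token_owner
    == "@admin:example.com"
)
```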
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
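The new `puppets_user_id` column feeds the `COALESCE` join added to `_query_for_auth` above: when the column is set, the token authenticates as the puppeted user while `token_owner` remains the admin who actually owns the token. A pure-Python illustration of that semantics, not Synapse code:

```python
from typing import Optional, Tuple


def resolve_token(
    access_token_user_id: str, puppets_user_id: Optional[str]
) -> Tuple[str, str]:
    # Mirrors `users.name = COALESCE(puppets_user_id, access_tokens.user_id)`.
    user_id = puppets_user_id if puppets_user_id is not None else access_token_user_id
    # Mirrors `access_tokens.user_id as token_owner`.
    token_owner = access_token_user_id
    return user_id, token_owner


# An ordinary token authenticates as its owner.
assert resolve_token("@alice:example.com", None) == (
    "@alice:example.com",
    "@alice:example.com",
)
# A puppet token acts as alice but remains owned by the admin.
assert resolve_token("@admin:example.com", "@alice:example.com") == (
    "@alice:example.com",
    "@admin:example.com",
)
```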