diff --git a/changelog.d/8616.misc b/changelog.d/8616.misc
new file mode 100644
index 0000000000..385b14063e
--- /dev/null
+++ b/changelog.d/8616.misc
@@ -0,0 +1 @@
+Change schema to support access tokens belonging to one user but granting access to another.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 526cb58c5f..bfcaf68b2a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -190,10 +191,6 @@ class Auth:
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
- request.authenticated_entity = user_id
- opentracing.set_tag("authenticated_entity", user_id)
- opentracing.set_tag("appservice_id", app_service.id)
-
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@@ -203,31 +200,38 @@ class Auth:
device_id="dummy-device", # stubbed
)
- return synapse.types.create_requester(user_id, app_service=app_service)
+ requester = synapse.types.create_requester(
+ user_id, app_service=app_service
+ )
+
+ request.requester = user_id
+ opentracing.set_tag("authenticated_entity", user_id)
+ opentracing.set_tag("user_id", user_id)
+ opentracing.set_tag("appservice_id", app_service.id)
+
+ return requester
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
- user = user_info["user"]
- token_id = user_info["token_id"]
- is_guest = user_info["is_guest"]
- shadow_banned = user_info["shadow_banned"]
+ token_id = user_info.token_id
+ is_guest = user_info.is_guest
+ shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
- user_id = user.to_string()
- if await self.store.is_account_expired(user_id, self.clock.time_msec()):
+ if await self.store.is_account_expired(
+ user_info.user_id, self.clock.time_msec()
+ ):
raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
- # device_id may not be present if get_user_by_access_token has been
- # stubbed out.
- device_id = user_info.get("device_id")
+ device_id = user_info.device_id
- if user and access_token and ip_addr:
+ if access_token and ip_addr:
await self.store.insert_client_ip(
- user_id=user.to_string(),
+ user_id=user_info.token_owner,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@@ -241,19 +245,23 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
- request.authenticated_entity = user.to_string()
- opentracing.set_tag("authenticated_entity", user.to_string())
- if device_id:
- opentracing.set_tag("device_id", device_id)
-
- return synapse.types.create_requester(
- user,
+ requester = synapse.types.create_requester(
+ user_info.user_id,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
+ authenticated_entity=user_info.token_owner,
)
+
+ request.requester = requester
+ opentracing.set_tag("authenticated_entity", user_info.token_owner)
+ opentracing.set_tag("user_id", user_info.user_id)
+ if device_id:
+ opentracing.set_tag("device_id", device_id)
+
+ return requester
except KeyError:
raise MissingClientTokenError()
@@ -284,7 +292,7 @@ class Auth:
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
- ) -> dict:
+ ) -> TokenLookupResult:
""" Validate access token and get user_id from it
Args:
@@ -293,13 +301,7 @@ class Auth:
allow this
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
- Returns:
- dict that includes:
- `user` (UserID)
- `is_guest` (bool)
- `shadow_banned` (bool)
- `token_id` (int|None): access token id. May be None if guest
- `device_id` (str|None): device corresponding to access token
+
Raises:
InvalidClientTokenError if a user by that token exists, but the token is
expired
@@ -309,9 +311,9 @@ class Auth:
if rights == "access":
# first look in the database
- r = await self._look_up_user_by_access_token(token)
+ r = await self.store.get_user_by_access_token(token)
if r:
- valid_until_ms = r["valid_until_ms"]
+ valid_until_ms = r.valid_until_ms
if (
not allow_expired
and valid_until_ms is not None
@@ -328,7 +330,6 @@ class Auth:
# otherwise it needs to be a valid macaroon
try:
user_id, guest = self._parse_and_validate_macaroon(token, rights)
- user = UserID.from_string(user_id)
if rights == "access":
if not guest:
@@ -354,23 +355,17 @@ class Auth:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
- ret = {
- "user": user,
- "is_guest": True,
- "shadow_banned": False,
- "token_id": None,
+
+ ret = TokenLookupResult(
+ user_id=user_id,
+ is_guest=True,
# all guests get the same device id
- "device_id": GUEST_DEVICE_ID,
- }
+ device_id=GUEST_DEVICE_ID,
+ )
elif rights == "delete_pusher":
# We don't store these tokens in the database
- ret = {
- "user": user,
- "is_guest": False,
- "shadow_banned": False,
- "token_id": None,
- "device_id": None,
- }
+
+ ret = TokenLookupResult(user_id=user_id, is_guest=False)
else:
raise RuntimeError("Unknown rights setting %s", rights)
return ret
@@ -479,31 +474,15 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
- async def _look_up_user_by_access_token(self, token):
- ret = await self.store.get_user_by_access_token(token)
- if not ret:
- return None
-
- # we use ret.get() below because *lots* of unit tests stub out
- # get_user_by_access_token in a way where it only returns a couple of
- # the fields.
- user_info = {
- "user": UserID.from_string(ret.get("name")),
- "token_id": ret.get("token_id", None),
- "is_guest": False,
- "shadow_banned": ret.get("shadow_banned"),
- "device_id": ret.get("device_id"),
- "valid_until_ms": ret.get("valid_until_ms"),
- }
- return user_info
-
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
- request.authenticated_entity = service.sender
+ request.requester = synapse.types.create_requester(
+ service.sender, app_service=service
+ )
return service
async def is_server_admin(self, user: UserID) -> bool:
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 3862d9c08f..66d008d2f4 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -52,11 +52,11 @@ class ApplicationService:
self,
token,
hostname,
+ id,
+ sender,
url=None,
namespaces=None,
hs_token=None,
- sender=None,
- id=None,
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3a6b95631e..a0933fae88 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -154,7 +154,7 @@ class Authenticator:
)
logger.debug("Request from %s", origin)
- request.authenticated_entity = origin
+ request.requester = origin
# If we get a valid signed request from the other side, its probably
# alive
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 276594f3d9..ff103cbb92 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -991,17 +991,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
- user_id=str(user_info["user"]),
- device_id=user_info["device_id"],
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
- if user_info["token_id"] is not None:
+ if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"],)
+ user_info.user_id, (user_info.token_id,)
)
async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a6f1d21674..ed1ff62599 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = await self.auth.get_user_by_access_token(guest_access_token)
- if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+ if (
+ not user_data.is_guest
+ or UserID.from_string(user_data.user_id).localpart != localpart
+ ):
raise AuthError(
403,
"Cannot register taken user ID without valid guest "
@@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
user_id=user_id,
diff --git a/synapse/http/site.py b/synapse/http/site.py
index ddb1770b09..5f0581dc3f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
import contextlib
import logging
import time
-from typing import Optional
+from typing import Optional, Union
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -54,9 +55,12 @@ class SynapseRequest(Request):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self._channel = channel # this is used by the tests
- self.authenticated_entity = None
self.start_time = 0.0
+ # The requester, if authenticated. For federation requests this is the
+ # server name, for client requests this is the Requester object.
+ self.requester = None # type: Optional[Union[Requester, str]]
+
# we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext]
@@ -271,11 +275,23 @@ class SynapseRequest(Request):
# to the client (nb may be negative)
response_send_time = self.finish_time - self._processing_finished_time
- # need to decode as it could be raw utf-8 bytes
- # from a IDN servname in an auth header
- authenticated_entity = self.authenticated_entity
- if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
- authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+ # Convert the requester into a string that we can log
+ authenticated_entity = None
+ if isinstance(self.requester, str):
+ authenticated_entity = self.requester
+ elif isinstance(self.requester, Requester):
+ authenticated_entity = self.requester.authenticated_entity
+
+ # If this is a request where the target user doesn't match the user who
+ # authenticated (e.g. and admin is puppetting a user) then we log both.
+ if self.requester.user.to_string() != authenticated_entity:
+ authenticated_entity = "{},{}".format(
+ authenticated_entity, self.requester.user.to_string(),
+ )
+ elif self.requester is not None:
+ # This shouldn't happen, but we log it so we don't lose information
+ # and can see that we're doing something wrong.
+ authenticated_entity = repr(self.requester) # type: ignore[unreachable]
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index e7cc74a5d2..f0c37eaf5e 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
requester = Requester.deserialize(self.store, content["requester"])
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index fc129dbaa7..8fa104c8d3 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
+ request.requester = requester
logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e7b17a7385..e5d07ce72a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -18,6 +18,8 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+ """Result of looking up an access token.
+
+ Attributes:
+ user_id: The user that this token authenticates as
+ is_guest
+ shadow_banned
+ token_id: The ID of the access token looked up
+ device_id: The device associated with the token, if any.
+ valid_until_ms: The timestamp the token expires, if any.
+ token_owner: The "owner" of the token. This is either the same as the
+ user, or a server admin who is logged in as the user.
+ """
+
+ user_id = attr.ib(type=str)
+ is_guest = attr.ib(type=bool, default=False)
+ shadow_banned = attr.ib(type=bool, default=False)
+ token_id = attr.ib(type=Optional[int], default=None)
+ device_id = attr.ib(type=Optional[str], default=None)
+ valid_until_ms = attr.ib(type=Optional[int], default=None)
+ token_owner = attr.ib(type=str)
+
+ # Make the token owner default to the user ID, which is the common case.
+ @token_owner.default
+ def _default_token_owner(self):
+ return self.user_id
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return is_trial
@cached()
- async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+ async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
- None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise a `TokenLookupResult`
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
- def _query_for_auth(self, txn, token):
+ def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
- SELECT users.name,
+ SELECT users.name as user_id,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
- access_tokens.valid_until_ms
+ access_tokens.valid_until_ms,
+ access_tokens.user_id as token_owner
FROM users
- INNER JOIN access_tokens on users.name = access_tokens.user_id
+ INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
"""
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
- return rows[0]
+ return TokenLookupResult(**rows[0])
return None
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
diff --git a/synapse/types.py b/synapse/types.py
index 5bde67cc07..66bb5bac8d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -29,6 +29,7 @@ from typing import (
Tuple,
Type,
TypeVar,
+ Union,
)
import attr
@@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING:
+ from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5
@@ -74,6 +76,7 @@ class Requester(
"shadow_banned",
"device_id",
"app_service",
+ "authenticated_entity",
],
)
):
@@ -104,6 +107,7 @@ class Requester(
"shadow_banned": self.shadow_banned,
"device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None,
+ "authenticated_entity": self.authenticated_entity,
}
@staticmethod
@@ -129,16 +133,18 @@ class Requester(
shadow_banned=input["shadow_banned"],
device_id=input["device_id"],
app_service=appservice,
+ authenticated_entity=input["authenticated_entity"],
)
def create_requester(
- user_id,
- access_token_id=None,
- is_guest=False,
- shadow_banned=False,
- device_id=None,
- app_service=None,
+ user_id: Union[str, "UserID"],
+ access_token_id: Optional[int] = None,
+ is_guest: Optional[bool] = False,
+ shadow_banned: Optional[bool] = False,
+ device_id: Optional[str] = None,
+ app_service: Optional["ApplicationService"] = None,
+ authenticated_entity: Optional[str] = None,
):
"""
Create a new ``Requester`` object
@@ -151,14 +157,27 @@ def create_requester(
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
+
+ if authenticated_entity is None:
+ authenticated_entity = user_id.to_string()
+
return Requester(
- user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+ user_id,
+ access_token_id,
+ is_guest,
+ shadow_banned,
+ device_id,
+ app_service,
+ authenticated_entity,
)
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cb6f29d670..0fd55f428a 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError,
ResourceLimitError,
)
+from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID
from tests import unittest
@@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
+ user_info = TokenLookupResult(
+ user_id=self.test_user, token_id=5, device_id="device"
+ )
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
- user_info = {"name": self.test_user, "token_id": "ditto"}
+ user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info)
)
@@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(
- {"name": "@baldrick:matrix.org", "device_id": "device"}
+ TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
)
)
@@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize())
)
- user = user_info["user"]
- self.assertEqual(UserID.from_string(user_id), user)
+ self.assertEqual(user_id, user_info.user_id)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
- self.assertEqual(user_info["device_id"], "device")
+ self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
@@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized)
)
- user = user_info["user"]
- is_guest = user_info["is_guest"]
- self.assertEqual(UserID.from_string(user_id), user)
- self.assertTrue(is_guest)
+ self.assertEqual(user_id, user_info.user_id)
+ self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks
@@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok:
return defer.succeed(None)
return defer.succeed(
- {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ TokenLookupResult(
+ user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
+ )
)
self.store.get_user_by_access_token = get_user
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 1e1f30d790..fe504d0869 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=True,
+ None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
@@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
- None, "example.com", id="foo", rate_limited=False,
+ None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
)
as_requester = create_requester("@user:example.com", app_service=appservice)
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 236b608d58..0bffeb1150 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
id="unique_identifier",
+ sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 4512c51311..875aaec2c6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
- self.assertEqual(user_info["device_id"], retrieved_device_id)
+ self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9f6f21a6e2..2e0fea04af 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
)
- self.token_id = self.info["token_id"]
+ self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index d9993e6245..bcdcafa5a9 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b567868b02..8571924b29 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_tuple["token_id"]
+ token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 2bdc6edbb1..67c27a089f 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
- token_id = user_dict["token_id"]
+ token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2fc3a60fc5..98c3887bbf 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
)
self.hs.get_datastore().services_cache.append(appservice)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 6b582771fe..c8c7a90e5d 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1])
)
- self.assertDictContainsSubset(
- {"name": self.user_id, "device_id": self.device_id}, result
- )
-
- self.assertTrue("token_id" in result)
+ self.assertEqual(result.user_id, self.user_id)
+ self.assertEqual(result.device_id, self.device_id)
+ self.assertIsNotNone(result.token_id)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0])
)
- self.assertEqual(self.user_id, user["name"])
+ self.assertEqual(self.user_id, user.user_id)
# now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
|