diff options
Diffstat (limited to 'synapse')
36 files changed, 606 insertions, 140 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 18f2ec3ae8..19f30c273c 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -50,7 +50,7 @@ class Filtering(object): # many definitions. top_level_definitions = [ - "presence" + "presence", "account_data" ] room_level_definitions = [ @@ -139,6 +139,10 @@ class FilterCollection(object): self.filter_json.get("presence", {}) ) + self.account_data = Filter( + self.filter_json.get("account_data", {}) + ) + def timeline_limit(self): return self.room_timeline_filter.limit() @@ -151,6 +155,9 @@ class FilterCollection(object): def filter_presence(self, events): return self.presence_filter.filter(events) + def filter_account_data(self, events): + return self.account_data.filter(events) + def filter_room_state(self, events): return self.room_state_filter.filter(events) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 127b4da4f8..6b164fd2d1 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -165,7 +165,7 @@ class BaseFederationServlet(object): if code is None: continue - server.register_path(method, pattern, self._wrap(code)) + server.register_paths(method, (pattern,), self._wrap(code)) class FederationSendServlet(BaseFederationServlet): diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 1d35d3b7dc..fe773bee9b 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -29,9 +29,10 @@ class AccountDataEventSource(object): last_stream_id = from_key current_stream_id = yield self.store.get_max_account_data_stream_id() - tags = yield self.store.get_updated_tags(user_id, last_stream_id) results = [] + tags = yield self.store.get_updated_tags(user_id, last_stream_id) + for room_id, room_tags in tags.items(): results.append({ "type": "m.tag", @@ -39,6 +40,24 @@ class AccountDataEventSource(object): "room_id": room_id, }) + account_data, room_account_data = ( + yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) + + for account_data_type, content in account_data.items(): + results.append({ + "type": account_data_type, + "content": content, + }) + + for room_id, account_data in room_account_data.items(): + for account_data_type, content in account_data.items(): + results.append({ + "type": account_data_type, + "content": content, + "room_id": room_id, + }) + defer.returnValue((results, current_stream_id)) @defer.inlineCallbacks diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 64c57375f7..e959ce50be 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -359,6 +359,10 @@ class MessageHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + public_room_ids = yield self.store.get_public_room_ids() limit = pagin_config.limit @@ -436,14 +440,22 @@ class MessageHandler(BaseHandler): for c in current_state.values() ] - account_data = [] + account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - account_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - d["account_data"] = account_data + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events except: logger.exception("Failed to get snapshot") @@ -456,9 +468,17 @@ class MessageHandler(BaseHandler): consumeErrors=True ).addErrback(unwrapFirstError) + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + ret = { "rooms": rooms_ret, "presence": presence, + "account_data": account_data_events, "receipts": receipt, "end": now_token.to_string(), } @@ -498,14 +518,22 @@ class MessageHandler(BaseHandler): user_id, room_id, pagin_config, membership, member_event_id, is_guest ) - account_data = [] + account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - account_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - result["account_data"] = account_data + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events defer.returnValue(result) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 877328b29e..943ce368ef 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -100,6 +100,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [ "next_batch", # Token for the next sync "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. @@ -195,6 +196,12 @@ class SyncHandler(BaseHandler): ) ) + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() + ) + ) + tags_by_room = yield self.store.get_tags_for_user( sync_config.user.to_string() ) @@ -211,6 +218,7 @@ class SyncHandler(BaseHandler): timeline_since_token=timeline_since_token, ephemeral_by_room=ephemeral_by_room, tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, ) joined.append(room_sync) elif event.membership == Membership.INVITE: @@ -230,11 +238,13 @@ class SyncHandler(BaseHandler): leave_token=leave_token, timeline_since_token=timeline_since_token, tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, ) archived.append(room_sync) defer.returnValue(SyncResult( presence=presence, + account_data=self.account_data_for_user(account_data), joined=joined, invited=invited, archived=archived, @@ -244,7 +254,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def full_state_sync_for_joined_room(self, room_id, sync_config, now_token, timeline_since_token, - ephemeral_by_room, tags_by_room): + ephemeral_by_room, tags_by_room, + account_data_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. @@ -262,19 +273,38 @@ class SyncHandler(BaseHandler): state=current_state, ephemeral=ephemeral_by_room.get(room_id, []), account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), )) - def account_data_for_room(self, room_id, tags_by_room): - account_data = [] + def account_data_for_user(self, account_data): + account_data_events = [] + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + return account_data_events + + def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): + account_data_events = [] tags = tags_by_room.get(room_id) if tags is not None: - account_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - return account_data + + account_data = account_data_by_room.get(room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + return account_data_events @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): @@ -341,7 +371,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def full_state_sync_for_archived_room(self, room_id, sync_config, leave_event_id, leave_token, - timeline_since_token, tags_by_room): + timeline_since_token, tags_by_room, + account_data_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. @@ -358,7 +389,7 @@ class SyncHandler(BaseHandler): timeline=batch, state=leave_state, account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), )) @@ -415,6 +446,13 @@ class SyncHandler(BaseHandler): since_token.account_data_key, ) + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + sync_config.user.to_string(), + since_token.account_data_key, + ) + ) + joined = [] archived = [] if len(room_events) <= timeline_limit: @@ -469,7 +507,7 @@ class SyncHandler(BaseHandler): state=state, ephemeral=ephemeral_by_room.get(room_id, []), account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), ) logger.debug("Result for room %s: %r", room_id, room_sync) @@ -492,14 +530,15 @@ class SyncHandler(BaseHandler): for room_id in joined_room_ids: room_sync = yield self.incremental_sync_with_gap_for_room( room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room + ephemeral_by_room, tags_by_room, account_data_by_room ) if room_sync: joined.append(room_sync) for leave_event in leave_events: room_sync = yield self.incremental_sync_for_archived_room( - sync_config, leave_event, since_token, tags_by_room + sync_config, leave_event, since_token, tags_by_room, + account_data_by_room ) archived.append(room_sync) @@ -510,6 +549,7 @@ class SyncHandler(BaseHandler): defer.returnValue(SyncResult( presence=presence, + account_data=self.account_data_for_user(account_data), joined=joined, invited=invited, archived=archived, @@ -566,7 +606,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_with_gap_for_room(self, room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room): + ephemeral_by_room, tags_by_room, + account_data_by_room): """ Get the incremental delta needed to bring the client up to date for the room. Gives the client the most recent events and the changes to state. @@ -606,7 +647,7 @@ class SyncHandler(BaseHandler): state=state, ephemeral=ephemeral_by_room.get(room_id, []), account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), ) @@ -616,7 +657,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_for_archived_room(self, sync_config, leave_event, - since_token, tags_by_room): + since_token, tags_by_room, + account_data_by_room): """ Get the incremental delta needed to bring the client up to date for the archived room. Returns: @@ -654,7 +696,7 @@ class SyncHandler(BaseHandler): timeline=batch, state=state_events_delta, account_data=self.account_data_for_room( - leave_event.room_id, tags_by_room + leave_event.room_id, tags_by_room, account_data_by_room ), ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 50feea6f1c..ef75be742c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -120,7 +120,7 @@ class HttpServer(object): """ Interface for registering callbacks on a HTTP server """ - def register_path(self, method, path_pattern, callback): + def register_paths(self, method, path_patterns, callback): """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -129,7 +129,7 @@ class HttpServer(object): Args: method (str): The method to listen to. - path_pattern (str): The regex used to match requests. + path_patterns (list<SRE_Pattern>): The regex used to match requests. callback (function): The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. @@ -165,10 +165,11 @@ class JsonResource(HttpServer, resource.Resource): self.version_string = hs.version_string self.hs = hs - def register_path(self, method, path_pattern, callback): - self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback) - ) + def register_paths(self, method, path_patterns, callback): + for path_pattern in path_patterns: + self.path_regexs.setdefault(method, []).append( + self._PathEntry(path_pattern, callback) + ) def render(self, request): """ This gets called by twisted every time someone sends us a request. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 9cda17fcf8..32b6d6cd72 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError import logging - logger = logging.getLogger(__name__) @@ -102,12 +101,13 @@ class RestServlet(object): def register(self, http_server): """ Register this servlet with the given HTTP server. """ - if hasattr(self, "PATTERN"): - pattern = self.PATTERN + if hasattr(self, "PATTERNS"): + patterns = self.PATTERNS for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): if hasattr(self, "on_%s" % (method,)): method_handler = getattr(self, "on_%s" % (method,)) - http_server.register_path(method, pattern, method_handler) + http_server.register_paths(method, patterns, method_handler) + else: raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index bdde43864c..0103697889 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)") + PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)", releases=()) @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 504a5e432f..7ae3839a19 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def client_path_pattern(path_regex): +def client_path_patterns(path_regex, releases=(0,)): """Creates a regex compiled client path with the correct client path prefix. @@ -37,7 +37,13 @@ def client_path_pattern(path_regex): Returns: SRE_Pattern """ - return re.compile("^" + CLIENT_PREFIX + path_regex) + patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] + unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release) + patterns.append(re.compile("^" + new_prefix + path_regex)) + return patterns class ClientV1RestServlet(RestServlet): diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 240eedac75..f488e2dd41 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError, Codes from synapse.types import RoomAlias -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import logging @@ -32,7 +32,7 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): - PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$") + PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") @defer.inlineCallbacks def on_GET(self, request, room_alias): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 3e1750d1a1..41b97e7d15 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.events.utils import serialize_event import logging @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class EventStreamRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/events$") + PATTERNS = client_path_patterns("/events$") DEFAULT_LONGPOLL_TIME_MS = 30000 @@ -72,7 +72,7 @@ class EventStreamRestServlet(ClientV1RestServlet): # TODO: Unit test gets, with and without auth, with different kinds of events. class EventRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") + PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$") def __init__(self, hs): super(EventRestServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 856a70f297..9ad3df8a9f 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,12 +16,12 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns # TODO: Needs unit testing class InitialSyncRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/initialSync$") + PATTERNS = client_path_patterns("/initialSync$") @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 720d6358e7..b0b641e430 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.http.client import SimpleHttpClient from synapse.types import UserID -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) class LoginRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login$") + PATTERNS = client_path_patterns("/login$", releases=()) PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" @@ -238,7 +238,7 @@ class LoginRestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/saml2") + PATTERNS = client_path_patterns("/login/saml2", releases=()) def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) @@ -282,7 +282,7 @@ class SAML2RestServlet(ClientV1RestServlet): # TODO Delete this after all CAS clients switch to token login instead class CasRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas") + PATTERNS = client_path_patterns("/login/cas", releases=()) def __init__(self, hs): super(CasRestServlet, self).__init__(hs) @@ -293,7 +293,7 @@ class CasRestServlet(ClientV1RestServlet): class CasRedirectServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas/redirect") + PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) def __init__(self, hs): super(CasRedirectServlet, self).__init__(hs) @@ -316,7 +316,7 @@ class CasRedirectServlet(ClientV1RestServlet): class CasTicketServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas/ticket") + PATTERNS = client_path_patterns("/login/cas/ticket", releases=()) def __init__(self, hs): super(CasTicketServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 48533f9d60..e0949fe4bb 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import logging @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") + PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -73,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)") + PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 3218e47025..e6c6e5d024 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,14 +16,14 @@ """ This module contains REST servlets to do with profile: /profile/<paths> """ from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.types import UserID import simplejson as json class ProfileDisplaynameRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") + PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -56,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") + PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -89,7 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)") + PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index b0870db1ac..edf5b0ca41 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import ( SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError ) -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) @@ -31,7 +31,7 @@ import simplejson as json class PushRuleRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/pushrules/.*$") + PATTERNS = client_path_patterns("/pushrules/.*$") SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index a110c0a4f0..6f465035b4 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,13 +17,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json class PusherRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/pushers/set$") + PATTERNS = client_path_patterns("/pushers/set$") @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index a56834e365..5b95d63e25 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor @@ -48,7 +48,7 @@ class RegisterRestServlet(ClientV1RestServlet): handler doesn't have a concept of multi-stages or sessions. """ - PATTERN = client_path_pattern("/register$") + PATTERNS = client_path_patterns("/register$", releases=()) def __init__(self, hs): super(RegisterRestServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6952d269ec..d86d266465 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,7 +16,7 @@ """ This module contains REST servlets to do with rooms: /rooms/<paths> """ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership @@ -34,16 +34,16 @@ class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here def register(self, http_server): - PATTERN = "/createRoom" - register_txn_path(self, PATTERN, http_server) + PATTERNS = "/createRoom" + register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_path("OPTIONS", - client_path_pattern("/rooms(?:/.*)?$"), - self.on_OPTIONS) + http_server.register_paths("OPTIONS", + client_path_patterns("/rooms(?:/.*)?$"), + self.on_OPTIONS) # define CORS for /createRoom[/txnid] - http_server.register_path("OPTIONS", - client_path_pattern("/createRoom(?:/.*)?$"), - self.on_OPTIONS) + http_server.register_paths("OPTIONS", + client_path_patterns("/createRoom(?:/.*)?$"), + self.on_OPTIONS) @defer.inlineCallbacks def on_PUT(self, request, txn_id): @@ -103,18 +103,18 @@ class RoomStateEventRestServlet(ClientV1RestServlet): state_key = ("/rooms/(?P<room_id>[^/]*)/state/" "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") - http_server.register_path("GET", - client_path_pattern(state_key), - self.on_GET) - http_server.register_path("PUT", - client_path_pattern(state_key), - self.on_PUT) - http_server.register_path("GET", - client_path_pattern(no_state_key), - self.on_GET_no_state_key) - http_server.register_path("PUT", - client_path_pattern(no_state_key), - self.on_PUT_no_state_key) + http_server.register_paths("GET", + client_path_patterns(state_key), + self.on_GET) + http_server.register_paths("PUT", + client_path_patterns(state_key), + self.on_PUT) + http_server.register_paths("GET", + client_path_patterns(no_state_key, releases=()), + self.on_GET_no_state_key) + http_server.register_paths("PUT", + client_path_patterns(no_state_key, releases=()), + self.on_PUT_no_state_key) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -170,8 +170,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERN = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") - register_txn_path(self, PATTERN, http_server, with_get=True) + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") + register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): @@ -215,8 +215,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERN = ("/join/(?P<room_identifier>[^/]*)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/join/(?P<room_identifier>[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): @@ -280,7 +280,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): # TODO: Needs unit testing class PublicRoomListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/publicRooms$") + PATTERNS = client_path_patterns("/publicRooms$") @defer.inlineCallbacks def on_GET(self, request): @@ -291,7 +291,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMemberListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -328,7 +328,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): # TODO: Needs better unit testing class RoomMessageListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -351,7 +351,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomStateRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -368,7 +368,7 @@ class RoomStateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomInitialSyncRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -384,7 +384,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomTriggerBackfill(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$") + PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/backfill$", releases=()) def __init__(self, hs): super(RoomTriggerBackfill, self).__init__(hs) @@ -408,7 +408,7 @@ class RoomTriggerBackfill(ClientV1RestServlet): class RoomEventContext(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" ) @@ -447,9 +447,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERN = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|kick|forget)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" + "(?P<membership_action>join|invite|leave|ban|kick|forget)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): @@ -543,8 +543,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): def register(self, http_server): - PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): @@ -582,7 +582,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" ) @@ -615,7 +615,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): class SearchRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/search$" ) @@ -655,20 +655,20 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): http_server : The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ - http_server.register_path( + http_server.register_paths( "POST", - client_path_pattern(regex_string + "$"), + client_path_patterns(regex_string + "$"), servlet.on_POST ) - http_server.register_path( + http_server.register_paths( "PUT", - client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), + client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), servlet.on_PUT ) if with_get: - http_server.register_path( + http_server.register_paths( "GET", - client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), + client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), servlet.on_GET ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index eb7c57cade..1567a03c89 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import hmac @@ -24,7 +24,7 @@ import base64 class VoipRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/voip/turnServer$") + PATTERNS = client_path_patterns("/voip/turnServer$") @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index a108132346..d7b59c84d1 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -23,6 +23,7 @@ from . import ( keys, tokenrefresh, tags, + account_data, ) from synapse.http.server import JsonResource @@ -46,3 +47,4 @@ class ClientV2AlphaRestResource(JsonResource): keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) + account_data.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 4540e8dcf7..7b8b879c03 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -27,7 +27,7 @@ import simplejson logger = logging.getLogger(__name__) -def client_v2_pattern(path_regex): +def client_v2_patterns(path_regex, releases=(0,)): """Creates a regex compiled client path with the correct client path prefix. @@ -37,7 +37,13 @@ def client_v2_pattern(path_regex): Returns: SRE_Pattern """ - return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) + patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) + patterns.append(re.compile("^" + new_prefix + path_regex)) + return patterns def parse_request_allow_empty(request): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1970ad3458..6f1c33f75b 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,7 +20,7 @@ from synapse.api.errors import LoginError, SynapseError, Codes from synapse.http.servlet import RestServlet from synapse.util.async import run_on_reactor -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request import logging @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class PasswordRestServlet(RestServlet): - PATTERN = client_v2_pattern("/account/password") + PATTERNS = client_v2_patterns("/account/password", releases=()) def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -89,7 +89,7 @@ class PasswordRestServlet(RestServlet): class ThreepidRestServlet(RestServlet): - PATTERN = client_v2_pattern("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid", releases=()) def __init__(self, hs): super(ThreepidRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py new file mode 100644 index 0000000000..5b8f454bf1 --- /dev/null +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import client_v2_patterns + +from synapse.http.servlet import RestServlet +from synapse.api.errors import AuthError, SynapseError + +from twisted.internet import defer + +import logging + +import simplejson as json + +logger = logging.getLogger(__name__) + + +class AccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + """ + PATTERNS = client_v2_patterns( + "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" + ) + + def __init__(self, hs): + super(AccountDataServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, account_data_type): + auth_user, _, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid JSON") + + max_id = yield self.store.add_account_data_for_user( + user_id, account_data_type, body + ) + + yield self.notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +class RoomAccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + """ + PATTERNS = client_v2_patterns( + "/user/(?P<user_id>[^/]*)" + "/rooms/(?P<room_id>[^/]*)" + "/account_data/(?P<account_data_type>[^/]*)" + ) + + def __init__(self, hs): + super(RoomAccountDataServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, room_id, account_data_type): + auth_user, _, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid JSON") + + if not isinstance(body, dict): + raise ValueError("Expected a JSON object") + + max_id = yield self.store.add_account_data_to_room( + user_id, room_id, account_data_type, body + ) + + yield self.notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + AccountDataServlet(hs).register(http_server) + RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 4c726f05f5..fb5947a141 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern +from ._base import client_v2_patterns import logging @@ -97,7 +97,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web") + PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 97956a4b91..3cd0364b56 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, SynapseError from synapse.http.servlet import RestServlet from synapse.types import UserID -from ._base import client_v2_pattern +from ._base import client_v2_patterns import simplejson as json import logging @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") + PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") def __init__(self, hs): super(GetFilterRestServlet, self).__init__() @@ -65,7 +65,7 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") + PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter") def __init__(self, hs): super(CreateFilterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 820d33336f..c55e85920f 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -21,7 +21,7 @@ from synapse.types import UserID from canonicaljson import encode_canonical_json -from ._base import client_v2_pattern +from ._base import client_v2_patterns import simplejson as json import logging @@ -54,7 +54,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)") + PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)") def __init__(self, hs): super(KeyUploadServlet, self).__init__() @@ -154,12 +154,13 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/keys/query(?:" "/(?P<user_id>[^/]*)(?:" "/(?P<device_id>[^/]*)" ")?" - ")?" + ")?", + releases=() ) def __init__(self, hs): @@ -245,10 +246,11 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/keys/claim(?:/?|(?:/" "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)" - ")?)" + ")?)", + releases=() ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 788acd4adb..aa214e13b6 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern +from ._base import client_v2_patterns import logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ReceiptRestServlet(RestServlet): - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/rooms/(?P<room_id>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)" "/(?P<event_id>[^/]*)$" diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f899376311..b2b89652c6 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -19,7 +19,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request import logging import hmac @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) class RegisterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/register") + PATTERNS = client_v2_patterns("/register") def __init__(self, hs): super(RegisterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 775f49885b..4efe802487 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,7 +25,7 @@ from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_room_id, ) from synapse.api.filtering import FilterCollection -from ._base import client_v2_pattern +from ._base import client_v2_patterns import copy import logging @@ -69,7 +69,7 @@ class SyncRestServlet(RestServlet): } """ - PATTERN = client_v2_pattern("/sync$") + PATTERNS = client_v2_patterns("/sync$") ALLOWED_PRESENCE = set(["online", "offline"]) def __init__(self, hs): @@ -144,6 +144,9 @@ class SyncRestServlet(RestServlet): ) response_content = { + "account_data": self.encode_account_data( + sync_result.account_data, filter, time_now + ), "presence": self.encode_presence( sync_result.presence, filter, time_now ), @@ -165,6 +168,9 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": filter.filter_presence(formatted)} + def encode_account_data(self, events, filter, time_now): + return {"events": filter.filter_account_data(events)} + def encode_joined(self, rooms, filter, time_now, token_id): """ Encode the joined rooms in a sync result diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index ba7223be11..b5d0db5569 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_pattern +from ._base import client_v2_patterns from synapse.http.servlet import RestServlet from synapse.api.errors import AuthError, SynapseError @@ -31,7 +31,7 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" ) @@ -56,7 +56,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 901e777983..5a63afd51e 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request class TokenRefreshRestServlet(RestServlet): @@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ - PATTERN = client_v2_pattern("/tokenrefresh") + PATTERNS = client_v2_patterns("/tokenrefresh") def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e7443f2838..c46b653f11 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -42,6 +42,7 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore from .search import SearchStore from .tags import TagsStore +from .account_data import AccountDataStore import logging @@ -73,6 +74,7 @@ class DataStore(RoomMemberStore, RoomStore, EndToEndKeyStore, SearchStore, TagsStore, + AccountDataStore, ): def __init__(self, hs): diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py new file mode 100644 index 0000000000..d1829f84e8 --- /dev/null +++ b/synapse/storage/account_data.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +import ujson as json +import logging + +logger = logging.getLogger(__name__) + + +class AccountDataStore(SQLBaseStore): + + def get_account_data_for_user(self, user_id): + """Get all the client account_data for a user. + + Args: + user_id(str): The user to get the account_data for. + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_account_data_for_user_txn(txn): + rows = self._simple_select_list_txn( + txn, "account_data", {"user_id": user_id}, + ["account_data_type", "content"] + ) + + global_account_data = { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + rows = self._simple_select_list_txn( + txn, "room_account_data", {"user_id": user_id}, + ["room_id", "account_data_type", "content"] + ) + + by_room = {} + for row in rows: + room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = json.loads(row["content"]) + + return (global_account_data, by_room) + + return self.runInteraction( + "get_account_data_for_user", get_account_data_for_user_txn + ) + + def get_account_data_for_room(self, user_id, room_id): + """Get all the client account_data for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + Returns: + A deferred dict of the room account_data + """ + def get_account_data_for_room_txn(txn): + rows = self._simple_select_list_txn( + txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, + ["account_data_type", "content"] + ) + + return { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + return self.runInteraction( + "get_account_data_for_room", get_account_data_for_room_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed. + + Args: + user_id(str): The user to get the account_data for. + stream_id(int): The point in the stream since which to get updates + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_updated_account_data_for_user_txn(txn): + sql = ( + "SELECT account_data_type, content FROM account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + global_account_data = { + row[0]: json.loads(row[1]) for row in txn.fetchall() + } + + sql = ( + "SELECT room_id, account_data_type, content FROM room_account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + account_data_by_room = {} + for row in txn.fetchall(): + room_account_data = account_data_by_room.setdefault(row[0], {}) + room_account_data[row[1]] = json.loads(row[2]) + + return (global_account_data, account_data_by_room) + + return self.runInteraction( + "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + ) + + @defer.inlineCallbacks + def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + def add_account_data_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + values={ + "stream_id": next_id, + "content": content_json, + } + ) + self._update_max_stream_id(txn, next_id) + + with (yield self._account_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction( + "add_room_account_data", add_account_data_txn, next_id + ) + + result = yield self._account_data_id_gen.get_max_token(self) + defer.returnValue(result) + + @defer.inlineCallbacks + def add_account_data_for_user(self, user_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + def add_account_data_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="account_data", + keyvalues={ + "user_id": user_id, + "account_data_type": account_data_type, + }, + values={ + "stream_id": next_id, + "content": content_json, + } + ) + self._update_max_stream_id(txn, next_id) + + with (yield self._account_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction( + "add_user_account_data", add_account_data_txn, next_id + ) + + result = yield self._account_data_id_gen.get_max_token(self) + defer.returnValue(result) + + def _update_max_stream_id(self, txn, next_id): + """Update the max stream_id + + Args: + txn: The database cursor + next_id(int): The the revision to advance to. + """ + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/schema/delta/26/account_data.sql index 3198a0d29c..48ad9cc6b8 100644 --- a/synapse/storage/schema/delta/26/account_data.sql +++ b/synapse/storage/schema/delta/26/account_data.sql @@ -15,3 +15,26 @@ ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; + + +CREATE TABLE IF NOT EXISTS account_data( + user_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) +); + + +CREATE TABLE IF NOT EXISTS room_account_data( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) +); + + +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index f6d826cc59..f520f60c6c 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -48,8 +48,8 @@ class TagsStore(SQLBaseStore): Args: user_id(str): The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to lists of tag - strings. + A deferred dict mapping from room_id strings to dicts mapping from + tag strings to tag content. """ deferred = self._simple_select_list( |