diff options
Diffstat (limited to '')
52 files changed, 2390 insertions, 333 deletions
diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index dbc6f6ffa5..15f8e1c48b 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -39,43 +39,43 @@ ROOMDOMAIN="meet.jit.si" #ROOMDOMAIN="conference.jitsi.vuc.me" class TrivialMatrixClient: - def __init__(self, access_token): - self.token = None - self.access_token = access_token - - def getEvent(self): - while True: - url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" - if self.token: - url += "&from="+self.token - req = grequests.get(url) - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print "incoming from matrix",obj - if 'end' not in obj: - continue - self.token = obj['end'] - if len(obj['chunk']): - return obj['chunk'][0] - - def joinRoom(self, roomId): - url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token - print url - headers={ 'Content-Type': 'application/json' } - req = grequests.post(url, headers=headers, data='{}') - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print "response: ",obj - - def sendEvent(self, roomId, evType, event): - url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token - print url - print json.dumps(event) - headers={ 'Content-Type': 'application/json' } - req = grequests.post(url, headers=headers, data=json.dumps(event)) - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print "response: ",obj + def __init__(self, access_token): + self.token = None + self.access_token = access_token + + def getEvent(self): + while True: + url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" + if self.token: + url += "&from="+self.token + req = grequests.get(url) + resps = grequests.map([req]) + obj = json.loads(resps[0].content) + print "incoming from matrix",obj + if 'end' not in obj: + continue + self.token = obj['end'] + if len(obj['chunk']): + return obj['chunk'][0] + + def joinRoom(self, roomId): + url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token + print url + headers={ 'Content-Type': 'application/json' } + req = grequests.post(url, headers=headers, data='{}') + resps = grequests.map([req]) + obj = json.loads(resps[0].content) + print "response: ",obj + + def sendEvent(self, roomId, evType, event): + url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token + print url + print json.dumps(event) + headers={ 'Content-Type': 'application/json' } + req = grequests.post(url, headers=headers, data=json.dumps(event)) + resps = grequests.map([req]) + obj = json.loads(resps[0].content) + print "response: ",obj @@ -83,178 +83,178 @@ xmppClients = {} def matrixLoop(): - while True: - ev = matrixCli.getEvent() - print ev - if ev['type'] == 'm.room.member': - print 'membership event' - if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: - roomId = ev['room_id'] - print "joining room %s" % (roomId) - matrixCli.joinRoom(roomId) - elif ev['type'] == 'm.room.message': - if ev['room_id'] in xmppClients: - print "already have a bridge for that user, ignoring" - continue - print "got message, connecting" - xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.invite': - print "Incoming call" - #sdp = ev['content']['offer']['sdp'] - #print "sdp: %s" % (sdp) - #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - #gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.answer': - print "Call answered" - sdp = ev['content']['answer']['sdp'] - if ev['room_id'] not in xmppClients: - print "We didn't have a call for that room" - continue - # should probably check call ID too - xmppCli = xmppClients[ev['room_id']] - xmppCli.sendAnswer(sdp) - elif ev['type'] == 'm.call.hangup': - if ev['room_id'] in xmppClients: - xmppClients[ev['room_id']].stop() - del xmppClients[ev['room_id']] - + while True: + ev = matrixCli.getEvent() + print ev + if ev['type'] == 'm.room.member': + print 'membership event' + if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: + roomId = ev['room_id'] + print "joining room %s" % (roomId) + matrixCli.joinRoom(roomId) + elif ev['type'] == 'm.room.message': + if ev['room_id'] in xmppClients: + print "already have a bridge for that user, ignoring" + continue + print "got message, connecting" + xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) + gevent.spawn(xmppClients[ev['room_id']].xmppLoop) + elif ev['type'] == 'm.call.invite': + print "Incoming call" + #sdp = ev['content']['offer']['sdp'] + #print "sdp: %s" % (sdp) + #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) + #gevent.spawn(xmppClients[ev['room_id']].xmppLoop) + elif ev['type'] == 'm.call.answer': + print "Call answered" + sdp = ev['content']['answer']['sdp'] + if ev['room_id'] not in xmppClients: + print "We didn't have a call for that room" + continue + # should probably check call ID too + xmppCli = xmppClients[ev['room_id']] + xmppCli.sendAnswer(sdp) + elif ev['type'] == 'm.call.hangup': + if ev['room_id'] in xmppClients: + xmppClients[ev['room_id']].stop() + del xmppClients[ev['room_id']] + class TrivialXmppClient: - def __init__(self, matrixRoom, userId): - self.rid = 0 - self.matrixRoom = matrixRoom - self.userId = userId - self.running = True - - def stop(self): - self.running = False - - def nextRid(self): - self.rid += 1 - return '%d' % (self.rid) - - def sendIq(self, xml): - fullXml = "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>" % (self.nextRid(), self.sid, xml) - #print "\t>>>%s" % (fullXml) - return self.xmppPoke(fullXml) - - def xmppPoke(self, xml): - headers = {'Content-Type': 'application/xml'} - req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) - resps = grequests.map([req]) - obj = BeautifulSoup(resps[0].content) - return obj - - def sendAnswer(self, answer): - print "sdp from matrix client",answer - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) - jingle, out_err = p.communicate(answer) - jingle = jingle % { - 'tojid': self.callfrom, - 'action': 'session-accept', - 'initiator': self.callfrom, - 'responder': self.jid, - 'sid': self.callsid - } - print "answer jingle from sdp",jingle - res = self.sendIq(jingle) - print "reply from answer: ",res - - self.ssrcs = {} - jingleSoup = BeautifulSoup(jingle) - for cont in jingleSoup.iq.jingle.findAll('content'): - if cont.description: - self.ssrcs[cont['name']] = cont.description['ssrc'] - print "my ssrcs:",self.ssrcs - - gevent.joinall([ - gevent.spawn(self.advertiseSsrcs) - ]) - - def advertiseSsrcs(self): + def __init__(self, matrixRoom, userId): + self.rid = 0 + self.matrixRoom = matrixRoom + self.userId = userId + self.running = True + + def stop(self): + self.running = False + + def nextRid(self): + self.rid += 1 + return '%d' % (self.rid) + + def sendIq(self, xml): + fullXml = "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>" % (self.nextRid(), self.sid, xml) + #print "\t>>>%s" % (fullXml) + return self.xmppPoke(fullXml) + + def xmppPoke(self, xml): + headers = {'Content-Type': 'application/xml'} + req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) + resps = grequests.map([req]) + obj = BeautifulSoup(resps[0].content) + return obj + + def sendAnswer(self, answer): + print "sdp from matrix client",answer + p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + jingle, out_err = p.communicate(answer) + jingle = jingle % { + 'tojid': self.callfrom, + 'action': 'session-accept', + 'initiator': self.callfrom, + 'responder': self.jid, + 'sid': self.callsid + } + print "answer jingle from sdp",jingle + res = self.sendIq(jingle) + print "reply from answer: ",res + + self.ssrcs = {} + jingleSoup = BeautifulSoup(jingle) + for cont in jingleSoup.iq.jingle.findAll('content'): + if cont.description: + self.ssrcs[cont['name']] = cont.description['ssrc'] + print "my ssrcs:",self.ssrcs + + gevent.joinall([ + gevent.spawn(self.advertiseSsrcs) + ]) + + def advertiseSsrcs(self): time.sleep(7) - print "SSRC spammer started" - while self.running: - ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } - res = self.sendIq(ssrcMsg) - print "reply from ssrc announce: ",res - time.sleep(10) - - - - def xmppLoop(self): - self.matrixCallId = time.time() - res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), HOST)) - - print res - self.sid = res.body['sid'] - print "sid %s" % (self.sid) - - res = self.sendIq("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>") - - res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), self.sid, HOST)) - - res = self.sendIq("<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>") - print res - - self.jid = res.body.iq.bind.jid.string - print "jid: %s" % (self.jid) - self.shortJid = self.jid.split('-')[0] - - res = self.sendIq("<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>") - - #randomthing = res.body.iq['to'] - #whatsitpart = randomthing.split('-')[0] - - #print "other random bind thing: %s" % (randomthing) - - # advertise preence to the jitsi room, with our nick - res = self.sendIq("<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)) - self.muc = {'users': []} - for p in res.body.findAll('presence'): - u = {} - u['shortJid'] = p['from'].split('/')[1] - if p.c and p.c.nick: - u['nick'] = p.c.nick.string - self.muc['users'].append(u) - print "muc: ",self.muc - - # wait for stuff - while True: - print "waiting..." - res = self.sendIq("") - print "got from stream: ",res - if res.body.iq: - jingles = res.body.iq.findAll('jingle') - if len(jingles): - self.callfrom = res.body.iq['from'] - self.handleInvite(jingles[0]) - elif 'type' in res.body and res.body['type'] == 'terminate': - self.running = False - del xmppClients[self.matrixRoom] - return - - def handleInvite(self, jingle): - self.initiator = jingle['initiator'] - self.callsid = jingle['sid'] - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) - print "raw jingle invite",str(jingle) - sdp, out_err = p.communicate(str(jingle)) - print "transformed remote offer sdp",sdp - inviteEvent = { - 'offer': { - 'type': 'offer', - 'sdp': sdp - }, - 'call_id': self.matrixCallId, - 'version': 0, - 'lifetime': 30000 - } - matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent) - + print "SSRC spammer started" + while self.running: + ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } + res = self.sendIq(ssrcMsg) + print "reply from ssrc announce: ",res + time.sleep(10) + + + + def xmppLoop(self): + self.matrixCallId = time.time() + res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), HOST)) + + print res + self.sid = res.body['sid'] + print "sid %s" % (self.sid) + + res = self.sendIq("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>") + + res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), self.sid, HOST)) + + res = self.sendIq("<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>") + print res + + self.jid = res.body.iq.bind.jid.string + print "jid: %s" % (self.jid) + self.shortJid = self.jid.split('-')[0] + + res = self.sendIq("<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>") + + #randomthing = res.body.iq['to'] + #whatsitpart = randomthing.split('-')[0] + + #print "other random bind thing: %s" % (randomthing) + + # advertise preence to the jitsi room, with our nick + res = self.sendIq("<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)) + self.muc = {'users': []} + for p in res.body.findAll('presence'): + u = {} + u['shortJid'] = p['from'].split('/')[1] + if p.c and p.c.nick: + u['nick'] = p.c.nick.string + self.muc['users'].append(u) + print "muc: ",self.muc + + # wait for stuff + while True: + print "waiting..." + res = self.sendIq("") + print "got from stream: ",res + if res.body.iq: + jingles = res.body.iq.findAll('jingle') + if len(jingles): + self.callfrom = res.body.iq['from'] + self.handleInvite(jingles[0]) + elif 'type' in res.body and res.body['type'] == 'terminate': + self.running = False + del xmppClients[self.matrixRoom] + return + + def handleInvite(self, jingle): + self.initiator = jingle['initiator'] + self.callsid = jingle['sid'] + p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + print "raw jingle invite",str(jingle) + sdp, out_err = p.communicate(str(jingle)) + print "transformed remote offer sdp",sdp + inviteEvent = { + 'offer': { + 'type': 'offer', + 'sdp': sdp + }, + 'call_id': self.matrixCallId, + 'version': 0, + 'lifetime': 30000 + } + matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent) + matrixCli = TrivialMatrixClient(ACCESS_TOKEN) gevent.joinall([ - gevent.spawn(matrixLoop) + gevent.spawn(matrixLoop) ]) diff --git a/setup.py b/setup.py index 043cd044a7..3249e87a96 100755 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ setup( install_requires=[ "syutil==0.0.2", "matrix_angular_sdk==0.6.0", - "Twisted>=14.0.0", + "Twisted==14.0.2", "service_identity>=1.0.0", "pyopenssl>=0.14", "pyyaml", diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a342a0e0da..9c03024512 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor -from synapse.types import UserID +from synapse.types import UserID, ClientInfo import logging @@ -290,7 +290,9 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - UserID : User ID object of the user making the request + tuple : of UserID and device string: + User ID object of the user making the request + Client ID object of the client instance the user is using Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -299,6 +301,8 @@ class Auth(object): access_token = request.args["access_token"][0] user_info = yield self.get_user_by_token(access_token) user = user_info["user"] + device_id = user_info["device_id"] + token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -314,7 +318,7 @@ class Auth(object): user_agent=user_agent ) - defer.returnValue(user) + defer.returnValue((user, ClientInfo(device_id, token_id))) except KeyError: raise AuthError(403, "Missing access token.") @@ -339,6 +343,7 @@ class Auth(object): "admin": bool(ret.get("admin", False)), "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), + "token_id": ret.get("token_id", None), } defer.returnValue(user_info) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 2b049debf3..ad478aa6b7 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) class Codes(object): + UNRECOGNIZED = "M_UNRECOGNIZED" UNAUTHORIZED = "M_UNAUTHORIZED" FORBIDDEN = "M_FORBIDDEN" BAD_JSON = "M_BAD_JSON" @@ -34,6 +35,7 @@ class Codes(object): LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" + MISSING_PARAM = "M_MISSING_PARAM", TOO_LARGE = "M_TOO_LARGE" @@ -81,6 +83,35 @@ class RegistrationError(SynapseError): pass +class UnrecognizedRequestError(SynapseError): + """An error indicating we don't understand the request you're trying to make""" + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.UNRECOGNIZED + message = None + if len(args) == 0: + message = "Unrecognized request" + else: + message = args[0] + super(UnrecognizedRequestError, self).__init__( + 400, + message, + **kwargs + ) + + +class NotFoundError(SynapseError): + """An error indicating we can't find the thing you asked for""" + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.NOT_FOUND + super(NotFoundError, self).__init__( + 404, + "Not found", + **kwargs + ) + + class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" diff --git a/synapse/api/urls.py b/synapse/api/urls.py index a299392049..693c0efda6 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -16,6 +16,7 @@ """Contains the URL paths to prefix various aspects of the server with. """ CLIENT_PREFIX = "/_matrix/client/api/v1" +CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" FEDERATION_PREFIX = "/_matrix/federation/v1" WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index fabe8ddacb..f20ccfb5b6 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -32,12 +32,13 @@ from synapse.http.server_key_resource import LocalKey from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, MEDIA_PREFIX + SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, ) from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.rest.client.v1 import ClientV1RestResource +from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource from daemonize import Daemonize import twisted.manhole.telnet @@ -62,6 +63,9 @@ class SynapseHomeServer(HomeServer): def build_resource_for_client(self): return ClientV1RestResource(self) + def build_resource_for_client_v2_alpha(self): + return ClientV2AlphaRestResource(self) + def build_resource_for_federation(self): return JsonResource() @@ -105,6 +109,7 @@ class SynapseHomeServer(HomeServer): # [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ] desired_tree = [ (CLIENT_PREFIX, self.get_resource_for_client()), + (CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()), (FEDERATION_PREFIX, self.get_resource_for_federation()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()), @@ -267,6 +272,8 @@ def setup(): bind_port = None hs.start_listening(bind_port, config.unsecure_port) + hs.get_pusherpool().start() + if config.daemonize: print config.pid_file daemon = Daemonize( diff --git a/synapse/events/utils.py b/synapse/events/utils.py index bcb5457278..e391aca4cc 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -89,31 +89,31 @@ def prune_event(event): return type(event)(allowed_fields) -def serialize_event(hs, e, client_event=True): +def serialize_event(e, time_now_ms, client_event=True): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e + time_now_ms = int(time_now_ms) + # Should this strip out None's? d = {k: v for k, v in e.get_dict().items()} if not client_event: # set the age and keep all other keys if "age_ts" in d["unsigned"]: - now = int(hs.get_clock().time_msec()) - d["unsigned"]["age"] = now - d["unsigned"]["age_ts"] + d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] return d if "age_ts" in d["unsigned"]: - now = int(hs.get_clock().time_msec()) - d["age"] = now - d["unsigned"]["age_ts"] + d["age"] = time_now_ms - d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"] d["user_id"] = d.pop("sender", None) if "redacted_because" in e.unsigned: d["redacted_because"] = serialize_event( - hs, e.unsigned["redacted_because"] + e.unsigned["redacted_because"], time_now_ms ) del d["unsigned"]["redacted_because"] diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 01e67b0818..025e7e7e62 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function from synapse.types import UserID +from synapse.events.utils import serialize_event from ._base import BaseHandler @@ -48,24 +49,25 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, - as_client_event=True): + as_client_event=True, affect_presence=True): auth_user = UserID.from_string(auth_user_id) try: - if auth_user not in self._streams_per_user: - self._streams_per_user[auth_user] = 0 - if auth_user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(auth_user) + if affect_presence: + if auth_user not in self._streams_per_user: + self._streams_per_user[auth_user] = 0 + if auth_user in self._stop_timer_per_user: + try: + self.clock.cancel_call_later( + self._stop_timer_per_user.pop(auth_user) + ) + except: + logger.exception("Failed to cancel event timer") + else: + yield self.distributor.fire( + "started_user_eventstream", auth_user ) - except: - logger.exception("Failed to cancel event timer") - else: - yield self.distributor.fire( - "started_user_eventstream", auth_user - ) - self._streams_per_user[auth_user] += 1 + self._streams_per_user[auth_user] += 1 if pagin_config.from_token is None: pagin_config.from_token = None @@ -78,8 +80,10 @@ class EventStreamHandler(BaseHandler): auth_user, room_ids, pagin_config, timeout ) + time_now = self.clock.time_msec() + chunks = [ - self.hs.serialize_event(e, as_client_event) for e in events + serialize_event(e, time_now, as_client_event) for e in events ] chunk = { @@ -91,27 +95,28 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) finally: - self._streams_per_user[auth_user] -= 1 - if not self._streams_per_user[auth_user]: - del self._streams_per_user[auth_user] - - # 10 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug( - "_later stopped_user_eventstream %s", auth_user - ) + if affect_presence: + self._streams_per_user[auth_user] -= 1 + if not self._streams_per_user[auth_user]: + del self._streams_per_user[auth_user] + + # 10 seconds of grace to allow the client to reconnect again + # before we think they're gone + def _later(): + logger.debug( + "_later stopped_user_eventstream %s", auth_user + ) - self._stop_timer_per_user.pop(auth_user, None) + self._stop_timer_per_user.pop(auth_user, None) - yield self.distributor.fire( - "stopped_user_eventstream", auth_user - ) + return self.distributor.fire( + "stopped_user_eventstream", auth_user + ) - logger.debug("Scheduling _later: for %s", auth_user) - self._stop_timer_per_user[auth_user] = ( - self.clock.call_later(30, _later) - ) + logger.debug("Scheduling _later: for %s", auth_user) + self._stop_timer_per_user[auth_user] = ( + self.clock.call_later(30, _later) + ) class EventHandler(BaseHandler): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6a1104a890..6fbd2af4ab 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import RoomError from synapse.streams.config import PaginationConfig +from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util.logcontext import PreserveLoggingContext from synapse.types import UserID @@ -100,9 +101,11 @@ class MessageHandler(BaseHandler): "room_key", next_key ) + time_now = self.clock.time_msec() + chunk = { "chunk": [ - self.hs.serialize_event(e, as_client_event) for e in events + serialize_event(e, time_now, as_client_event) for e in events ], "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), @@ -111,7 +114,8 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True): + def create_and_send_event(self, event_dict, ratelimit=True, + client=None, txn_id=None): """ Given a dict from a client, create and handle a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -145,6 +149,15 @@ class MessageHandler(BaseHandler): builder.content ) + if client is not None: + if client.token_id is not None: + builder.internal_metadata.token_id = client.token_id + if client.device_id is not None: + builder.internal_metadata.device_id = client.device_id + + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id + event, context = yield self._create_new_client_event( builder=builder, ) @@ -211,7 +224,8 @@ class MessageHandler(BaseHandler): # TODO: This is duplicating logic from snapshot_all_rooms current_state = yield self.state_handler.get_current_state(room_id) - defer.returnValue([self.hs.serialize_event(c) for c in current_state]) + now = self.clock.time_msec() + defer.returnValue([serialize_event(c, now) for c in current_state]) @defer.inlineCallbacks def snapshot_all_rooms(self, user_id=None, pagin_config=None, @@ -283,10 +297,11 @@ class MessageHandler(BaseHandler): start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() d["messages"] = { "chunk": [ - self.hs.serialize_event(m, as_client_event) + serialize_event(m, time_now, as_client_event) for m in messages ], "start": start_token.to_string(), @@ -297,7 +312,8 @@ class MessageHandler(BaseHandler): event.room_id ) d["state"] = [ - self.hs.serialize_event(c) for c in current_state + serialize_event(c, time_now, as_client_event) + for c in current_state ] except: logger.exception("Failed to get snapshot") @@ -320,8 +336,9 @@ class MessageHandler(BaseHandler): auth_user = UserID.from_string(user_id) # TODO: These concurrently + time_now = self.clock.time_msec() state_tuples = yield self.state_handler.get_current_state(room_id) - state = [self.hs.serialize_event(x) for x in state_tuples] + state = [serialize_event(x, time_now) for x in state_tuples] member_event = (yield self.store.get_room_member( user_id=user_id, @@ -360,11 +377,13 @@ class MessageHandler(BaseHandler): "Failed to get member presence of %r", m.user_id ) + time_now = self.clock.time_msec() + defer.returnValue({ "membership": member_event.membership, "room_id": room_id, "messages": { - "chunk": [self.hs.serialize_event(m) for m in messages], + "chunk": [serialize_event(m, time_now) for m in messages], "start": start_token.to_string(), "end": end_token.to_string(), }, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d66bfea7b1..cd0798c2b0 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler): "changed_presencelike_data", self.changed_presencelike_data ) + # outbound signal from the presence module to advertise when a user's + # presence has changed + distributor.declare("user_presence_changed") + self.distributor = distributor self.federation = hs.get_replication_layer() @@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler): room_ids=room_ids, statuscache=statuscache, ) + yield self.distributor.fire("user_presence_changed", user, statuscache) @defer.inlineCallbacks def _push_presence_remote(self, user, destination, state=None): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 732652c228..66a89c10b2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler): # each request httpCli = SimpleHttpClient(self.hs) # XXX: make this configurable! - trustedIdServers = ['matrix.org:8090'] + trustedIdServers = ['matrix.org:8090', 'matrix.org'] if not creds['idServer'] in trustedIdServers: logger.warn('%s is not a trusted ID server: rejecting 3pid ' + 'credentials', creds['idServer']) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index edb96cec83..23821d321f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -16,12 +16,14 @@ """Contains functions for performing events on rooms.""" from twisted.internet import defer +from ._base import BaseHandler + from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import StoreError, SynapseError from synapse.util import stringutils from synapse.util.async import run_on_reactor -from ._base import BaseHandler +from synapse.events.utils import serialize_event import logging @@ -293,8 +295,9 @@ class RoomMemberHandler(BaseHandler): yield self.auth.check_joined_room(room_id, user_id) member_list = yield self.store.get_room_members(room_id=room_id) + time_now = self.clock.time_msec() event_list = [ - self.hs.serialize_event(entry) + serialize_event(entry, time_now) for entry in member_list ] chunk_data = { diff --git a/synapse/http/client.py b/synapse/http/client.py index 7793bab106..198f575cfa 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -63,6 +63,25 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) @defer.inlineCallbacks + def post_json_get_json(self, uri, post_json): + json_str = json.dumps(post_json) + + logger.info("HTTP POST %s -> %s", json_str, uri) + + response = yield self.agent.request( + "POST", + uri.encode("ascii"), + headers=Headers({ + "Content-Type": ["application/json"] + }), + bodyProducer=FileBodyProducer(StringIO(json_str)) + ) + + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + + @defer.inlineCallbacks def get_json(self, uri, args={}): """ Get's some json from the given host and path diff --git a/synapse/http/server.py b/synapse/http/server.py index 8015a22edf..0f6539e1be 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,7 +16,7 @@ from synapse.http.agent_name import AGENT_NAME from synapse.api.errors import ( - cs_exception, SynapseError, CodeMessageException + cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError ) from synapse.util.logcontext import LoggingContext @@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource): return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - self._send_response( - request, - 400, - {"error": "Unrecognized request"} - ) + raise UnrecognizedRequestError() except CodeMessageException as e: if isinstance(e, SynapseError): logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index d5ccf2742f..a4eb6c817c 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -15,6 +15,8 @@ """ This module contains base REST classes for constructing REST servlets. """ +from synapse.api.errors import SynapseError + import logging @@ -54,3 +56,58 @@ class RestServlet(object): http_server.register_path(method, pattern, method_handler) else: raise NotImplementedError("RestServlet must register something.") + + @staticmethod + def parse_integer(request, name, default=None, required=False): + if name in request.args: + try: + return int(request.args[name][0]) + except: + message = "Query parameter %r must be an integer" % (name,) + raise SynapseError(400, message) + else: + if required: + message = "Missing integer query parameter %r" % (name,) + raise SynapseError(400, message) + else: + return default + + @staticmethod + def parse_boolean(request, name, default=None, required=False): + if name in request.args: + try: + return { + "true": True, + "false": False, + }[request.args[name][0]] + except: + message = ( + "Boolean query parameter %r must be one of" + " ['true', 'false']" + ) % (name,) + raise SynapseError(400, message) + else: + if required: + message = "Missing boolean query parameter %r" % (name,) + raise SynapseError(400, message) + else: + return default + + @staticmethod + def parse_string(request, name, default=None, required=False, + allowed_values=None, param_type="string"): + if name in request.args: + value = request.args[name][0] + if allowed_values is not None and value not in allowed_values: + message = "Query parameter %r must be one of [%s]" % ( + name, ", ".join(repr(v) for v in allowed_values) + ) + raise SynapseError(message) + else: + return value + else: + if required: + message = "Missing %s query parameter %r" % (param_type, name) + raise SynapseError(400, message) + else: + return default diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py new file mode 100644 index 0000000000..19478c72a2 --- /dev/null +++ b/synapse/push/__init__.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.streams.config import PaginationConfig +from synapse.types import StreamToken, UserID + +import synapse.util.async +import baserules + +import logging +import fnmatch +import json + +logger = logging.getLogger(__name__) + + +class Pusher(object): + INITIAL_BACKOFF = 1000 + MAX_BACKOFF = 60 * 60 * 1000 + GIVE_UP_AFTER = 24 * 60 * 60 * 1000 + DEFAULT_ACTIONS = ['notify'] + + def __init__(self, _hs, instance_handle, user_name, app_id, + app_display_name, device_display_name, pushkey, pushkey_ts, + data, last_token, last_success, failing_since): + self.hs = _hs + self.evStreamHandler = self.hs.get_handlers().event_stream_handler + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + self.instance_handle = instance_handle + self.user_name = user_name + self.app_id = app_id + self.app_display_name = app_display_name + self.device_display_name = device_display_name + self.pushkey = pushkey + self.pushkey_ts = pushkey_ts + self.data = data + self.last_token = last_token + self.last_success = last_success # not actually used + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.failing_since = failing_since + self.alive = True + + # The last value of last_active_time that we saw + self.last_last_active_time = 0 + self.has_unread = True + + @defer.inlineCallbacks + def _actions_for_event(self, ev): + """ + This should take into account notification settings that the user + has configured both globally and per-room when we have the ability + to do such things. + """ + if ev['user_id'] == self.user_name: + # let's assume you probably know about messages you sent yourself + defer.returnValue(['dont_notify']) + + if ev['type'] == 'm.room.member': + if ev['state_key'] != self.user_name: + defer.returnValue(['dont_notify']) + + rules = yield self.store.get_push_rules_for_user_name(self.user_name) + + for r in rules: + r['conditions'] = json.loads(r['conditions']) + r['actions'] = json.loads(r['actions']) + + user_name_localpart = UserID.from_string(self.user_name).localpart + + rules.extend(baserules.make_base_rules(user_name_localpart)) + + # get *our* member event for display name matching + member_events_for_room = yield self.store.get_current_state( + room_id=ev['room_id'], + event_type='m.room.member', + state_key=self.user_name + ) + my_display_name = None + if len(member_events_for_room) > 0: + my_display_name = member_events_for_room[0].content['displayname'] + + for r in rules: + matches = True + + conditions = r['conditions'] + actions = r['actions'] + + for c in conditions: + matches &= self._event_fulfills_condition( + ev, c, display_name=my_display_name + ) + # ignore rules with no actions (we have an explict 'dont_notify' + if len(actions) == 0: + logger.warn( + "Ignoring rule id %s with no actions for user %s" % + (r['rule_id'], r['user_name']) + ) + continue + if matches: + defer.returnValue(actions) + + defer.returnValue(Pusher.DEFAULT_ACTIONS) + + def _event_fulfills_condition(self, ev, condition, display_name): + if condition['kind'] == 'event_match': + if 'pattern' not in condition: + logger.warn("event_match condition with no pattern") + return False + pat = condition['pattern'] + + val = _value_for_dotted_key(condition['key'], ev) + if val is None: + return False + return fnmatch.fnmatch(val.upper(), pat.upper()) + elif condition['kind'] == 'device': + if 'instance_handle' not in condition: + return True + return condition['instance_handle'] == self.instance_handle + elif condition['kind'] == 'contains_display_name': + # This is special because display names can be different + # between rooms and so you can't really hard code it in a rule. + # Optimisation: we should cache these names and update them from + # the event stream. + if 'content' not in ev or 'body' not in ev['content']: + return False + return fnmatch.fnmatch( + ev['content']['body'].upper(), "*%s*" % (display_name.upper(),) + ) + else: + return True + + @defer.inlineCallbacks + def get_context_for_event(self, ev): + name_aliases = yield self.store.get_room_name_and_aliases( + ev['room_id'] + ) + + ctx = {'aliases': name_aliases[1]} + if name_aliases[0] is not None: + ctx['name'] = name_aliases[0] + + their_member_events_for_room = yield self.store.get_current_state( + room_id=ev['room_id'], + event_type='m.room.member', + state_key=ev['user_id'] + ) + if len(their_member_events_for_room) > 0: + dn = their_member_events_for_room[0].content['displayname'] + if dn is not None: + ctx['sender_display_name'] = dn + + defer.returnValue(ctx) + + @defer.inlineCallbacks + def start(self): + if not self.last_token: + # First-time setup: get a token to start from (we can't + # just start from no token, ie. 'now' + # because we need the result to be reproduceable in case + # we fail to dispatch the push) + config = PaginationConfig(from_token=None, limit='1') + chunk = yield self.evStreamHandler.get_stream( + self.user_name, config, timeout=0) + self.last_token = chunk['end'] + self.store.update_pusher_last_token( + self.user_name, self.pushkey, self.last_token) + logger.info("Pusher %s for user %s starting from token %s", + self.pushkey, self.user_name, self.last_token) + + while self.alive: + from_tok = StreamToken.from_string(self.last_token) + config = PaginationConfig(from_token=from_tok, limit='1') + chunk = yield self.evStreamHandler.get_stream( + self.user_name, config, + timeout=100*365*24*60*60*1000, affect_presence=False + ) + + # limiting to 1 may get 1 event plus 1 presence event, so + # pick out the actual event + single_event = None + for c in chunk['chunk']: + if 'event_id' in c: # Hmmm... + single_event = c + break + if not single_event: + self.last_token = chunk['end'] + continue + + if not self.alive: + continue + + processed = False + actions = yield self._actions_for_event(single_event) + tweaks = _tweaks_for_actions(actions) + + if len(actions) == 0: + logger.warn("Empty actions! Using default action.") + actions = Pusher.DEFAULT_ACTIONS + if 'notify' not in actions and 'dont_notify' not in actions: + logger.warn("Neither notify nor dont_notify in actions: adding default") + actions.extend(Pusher.DEFAULT_ACTIONS) + if 'dont_notify' in actions: + logger.debug( + "%s for %s: dont_notify", + single_event['event_id'], self.user_name + ) + processed = True + else: + rejected = yield self.dispatch_push(single_event, tweaks) + self.has_unread = True + if isinstance(rejected, list) or isinstance(rejected, tuple): + processed = True + for pk in rejected: + if pk != self.pushkey: + # for sanity, we only remove the pushkey if it + # was the one we actually sent... + logger.warn( + ("Ignoring rejected pushkey %s because we" + " didn't send it"), pk + ) + else: + logger.info( + "Pushkey %s was rejected: removing", + pk + ) + yield self.hs.get_pusherpool().remove_pusher( + self.app_id, pk + ) + + if not self.alive: + continue + + if processed: + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.last_token = chunk['end'] + self.store.update_pusher_last_token_and_success( + self.user_name, + self.pushkey, + self.last_token, + self.clock.time_msec() + ) + if self.failing_since: + self.failing_since = None + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since) + else: + if not self.failing_since: + self.failing_since = self.clock.time_msec() + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since + ) + + if (self.failing_since and + self.failing_since < + self.clock.time_msec() - Pusher.GIVE_UP_AFTER): + # we really only give up so that if the URL gets + # fixed, we don't suddenly deliver a load + # of old notifications. + logger.warn("Giving up on a notification to user %s, " + "pushkey %s", + self.user_name, self.pushkey) + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.last_token = chunk['end'] + self.store.update_pusher_last_token( + self.user_name, + self.pushkey, + self.last_token + ) + + self.failing_since = None + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since + ) + else: + logger.warn("Failed to dispatch push for user %s " + "(failing for %dms)." + "Trying again in %dms", + self.user_name, + self.clock.time_msec() - self.failing_since, + self.backoff_delay) + yield synapse.util.async.sleep(self.backoff_delay / 1000.0) + self.backoff_delay *= 2 + if self.backoff_delay > Pusher.MAX_BACKOFF: + self.backoff_delay = Pusher.MAX_BACKOFF + + def stop(self): + self.alive = False + + def dispatch_push(self, p, tweaks): + """ + Overridden by implementing classes to actually deliver the notification + Args: + p: The event to notify for as a single event from the event stream + Returns: If the notification was delivered, an array containing any + pushkeys that were rejected by the push gateway. + False if the notification could not be delivered (ie. + should be retried). + """ + pass + + def reset_badge_count(self): + pass + + def presence_changed(self, state): + """ + We clear badge counts whenever a user's last_active time is bumped + This is by no means perfect but I think it's the best we can do + without read receipts. + """ + if 'last_active' in state.state: + last_active = state.state['last_active'] + if last_active > self.last_last_active_time: + self.last_last_active_time = last_active + if self.has_unread: + logger.info("Resetting badge count for %s", self.user_name) + self.reset_badge_count() + self.has_unread = False + + +def _value_for_dotted_key(dotted_key, event): + parts = dotted_key.split(".") + val = event + while len(parts) > 0: + if parts[0] not in val: + return None + val = val[parts[0]] + parts = parts[1:] + return val + + +def _tweaks_for_actions(actions): + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if 'set_sound' in a: + tweaks['sound'] = a['set_sound'] + return tweaks + + +class PusherConfigException(Exception): + def __init__(self, msg): + super(PusherConfigException, self).__init__(msg) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py new file mode 100644 index 0000000000..4caf7beed2 --- /dev/null +++ b/synapse/push/baserules.py @@ -0,0 +1,35 @@ +def make_base_rules(user_name): + """ + Nominally we reserve priority class 0 for these rules, although + in practice we just append them to the end so we don't actually need it. + """ + return [ + { + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': '*%s*' % (user_name,), # Matrix ID match + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + }, + { + 'conditions': [ + { + 'kind': 'contains_display_name' + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + }, + ] \ No newline at end of file diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py new file mode 100644 index 0000000000..d4c5f03b01 --- /dev/null +++ b/synapse/push/httppusher.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.push import Pusher, PusherConfigException +from synapse.http.client import SimpleHttpClient + +from twisted.internet import defer + +import logging + +logger = logging.getLogger(__name__) + + +class HttpPusher(Pusher): + def __init__(self, _hs, instance_handle, user_name, app_id, + app_display_name, device_display_name, pushkey, pushkey_ts, + data, last_token, last_success, failing_since): + super(HttpPusher, self).__init__( + _hs, + instance_handle, + user_name, + app_id, + app_display_name, + device_display_name, + pushkey, + pushkey_ts, + data, + last_token, + last_success, + failing_since + ) + if 'url' not in data: + raise PusherConfigException( + "'url' required in data for HTTP pusher" + ) + self.url = data['url'] + self.httpCli = SimpleHttpClient(self.hs) + self.data_minus_url = {} + self.data_minus_url.update(self.data) + del self.data_minus_url['url'] + + @defer.inlineCallbacks + def _build_notification_dict(self, event, tweaks): + # we probably do not want to push for every presence update + # (we may want to be able to set up notifications when specific + # people sign in, but we'd want to only deliver the pertinent ones) + # Actually, presence events will not get this far now because we + # need to filter them out in the main Pusher code. + if 'event_id' not in event: + defer.returnValue(None) + + ctx = yield self.get_context_for_event(event) + + d = { + 'notification': { + 'id': event['event_id'], + 'type': event['type'], + 'sender': event['user_id'], + 'counts': { # -- we don't mark messages as read yet so + # we have no way of knowing + # Just set the badge to 1 until we have read receipts + 'unread': 1, + # 'missed_calls': 2 + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + 'tweaks': tweaks + } + ] + } + } + if event['type'] == 'm.room.member': + d['notification']['membership'] = event['content']['membership'] + if 'content' in event: + d['notification']['content'] = event['content'] + + if len(ctx['aliases']): + d['notification']['room_alias'] = ctx['aliases'][0] + if 'sender_display_name' in ctx: + d['notification']['sender_display_name'] = ctx['sender_display_name'] + if 'name' in ctx: + d['notification']['room_name'] = ctx['name'] + + defer.returnValue(d) + + @defer.inlineCallbacks + def dispatch_push(self, event, tweaks): + notification_dict = yield self._build_notification_dict(event, tweaks) + if not notification_dict: + defer.returnValue([]) + try: + resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) + except: + logger.exception("Failed to push %s ", self.url) + defer.returnValue(False) + rejected = [] + if 'rejected' in resp: + rejected = resp['rejected'] + defer.returnValue(rejected) + + @defer.inlineCallbacks + def reset_badge_count(self): + d = { + 'notification': { + 'id': '', + 'type': None, + 'sender': '', + 'counts': { + 'unread': 0, + 'missed_calls': 0 + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + } + ] + } + } + try: + resp = yield self.httpCli.post_json_get_json(self.url, d) + except: + logger.exception("Failed to push %s ", self.url) + defer.returnValue(False) + rejected = [] + if 'rejected' in resp: + rejected = resp['rejected'] + defer.returnValue(rejected) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py new file mode 100644 index 0000000000..4892c21e7b --- /dev/null +++ b/synapse/push/pusherpool.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from httppusher import HttpPusher +from synapse.push import PusherConfigException + +import logging +import json + +logger = logging.getLogger(__name__) + + +class PusherPool: + def __init__(self, _hs): + self.hs = _hs + self.store = self.hs.get_datastore() + self.pushers = {} + self.last_pusher_started = -1 + + distributor = self.hs.get_distributor() + distributor.observe( + "user_presence_changed", self.user_presence_changed + ) + + @defer.inlineCallbacks + def user_presence_changed(self, user, state): + user_name = user.to_string() + + # until we have read receipts, pushers use this to reset a user's + # badge counters to zero + for p in self.pushers.values(): + if p.user_name == user_name: + yield p.presence_changed(state) + + @defer.inlineCallbacks + def start(self): + pushers = yield self.store.get_all_pushers() + for p in pushers: + p['data'] = json.loads(p['data']) + self._start_pushers(pushers) + + @defer.inlineCallbacks + def add_pusher(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data): + # we try to create the pusher just to validate the config: it + # will then get pulled out of the database, + # recreated, added and started: this means we have only one + # code path adding pushers. + self._create_pusher({ + "user_name": user_name, + "kind": kind, + "instance_handle": instance_handle, + "app_id": app_id, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "pushkey": pushkey, + "pushkey_ts": self.hs.get_clock().time_msec(), + "lang": lang, + "data": data, + "last_token": None, + "last_success": None, + "failing_since": None + }) + yield self._add_pusher_to_store( + user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, lang, data + ) + + @defer.inlineCallbacks + def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, lang, data): + yield self.store.add_pusher( + user_name=user_name, + instance_handle=instance_handle, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=self.hs.get_clock().time_msec(), + lang=lang, + data=json.dumps(data) + ) + self._refresh_pusher((app_id, pushkey)) + + def _create_pusher(self, pusherdict): + if pusherdict['kind'] == 'http': + return HttpPusher( + self.hs, + instance_handle=pusherdict['instance_handle'], + user_name=pusherdict['user_name'], + app_id=pusherdict['app_id'], + app_display_name=pusherdict['app_display_name'], + device_display_name=pusherdict['device_display_name'], + pushkey=pusherdict['pushkey'], + pushkey_ts=pusherdict['pushkey_ts'], + data=pusherdict['data'], + last_token=pusherdict['last_token'], + last_success=pusherdict['last_success'], + failing_since=pusherdict['failing_since'] + ) + else: + raise PusherConfigException( + "Unknown pusher type '%s' for user %s" % + (pusherdict['kind'], pusherdict['user_name']) + ) + + @defer.inlineCallbacks + def _refresh_pusher(self, app_id_pushkey): + p = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id_pushkey + ) + p['data'] = json.loads(p['data']) + + self._start_pushers([p]) + + def _start_pushers(self, pushers): + logger.info("Starting %d pushers", len(pushers)) + for pusherdict in pushers: + p = self._create_pusher(pusherdict) + if p: + fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) + if fullid in self.pushers: + self.pushers[fullid].stop() + self.pushers[fullid] = p + p.start() + + @defer.inlineCallbacks + def remove_pusher(self, app_id, pushkey): + fullid = "%s:%s" % (app_id, pushkey) + if fullid in self.pushers: + logger.info("Stopping pusher %s", fullid) + self.pushers[fullid].stop() + del self.pushers[fullid] + yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index b1fae991e0..826a36f203 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -5,8 +5,8 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "syutil==0.0.2": ["syutil"], - "matrix_angular_sdk==0.6.0": ["syweb==0.6.0"], - "Twisted>=14.0.0": ["twisted>=14.0.0"], + "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"], + "Twisted==14.0.2": ["twisted==14.0.2"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py index 8bb89b2f6a..d8d01cdd16 100644 --- a/synapse/rest/client/v1/__init__.py +++ b/synapse/rest/client/v1/__init__.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - from . import ( room, events, register, login, profile, presence, initial_sync, directory, - voip, admin, + voip, admin, pusher, push_rule ) from synapse.http.server import JsonResource @@ -41,3 +40,5 @@ class ClientV1RestResource(JsonResource): directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource) + pusher.register_servlets(hs, client_resource) + push_rule.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 1051d96f96..2ce754b028 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 15ae8749b8..8f65efec5f 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) if not "room_id" in content: @@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, room_alias): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index c69de56863..77b7c25a03 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig from .base import ClientV1RestServlet, client_path_pattern +from synapse.events.utils import serialize_event import logging @@ -33,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -64,14 +65,19 @@ class EventStreamRestServlet(ClientV1RestServlet): class EventRestServlet(ClientV1RestServlet): PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") + def __init__(self, hs): + super(EventRestServlet, self).__init__(hs) + self.clock = hs.get_clock() + @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) + time_now = self.clock.time_msec() if event: - defer.returnValue((200, self.hs.serialize_event(event))) + defer.returnValue((200, serialize_event(event, time_now))) else: defer.returnValue((404, "Event not found.")) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 357fa845b4..4a259bba64 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index b6c207e662..7feb4aadb1 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 24f8d56952..15d6f3fc6c 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py new file mode 100644 index 0000000000..0f78fa667c --- /dev/null +++ b/synapse/rest/client/v1/push_rule.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \ + StoreError +from .base import ClientV1RestServlet, client_path_pattern +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException + +import json + + +class PushRuleRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/pushrules/.*$") + PRIORITY_CLASS_MAP = { + 'underride': 1, + 'sender': 2, + 'room': 3, + 'content': 4, + 'override': 5, + } + PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()} + SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( + "Unrecognised request: You probably wanted a trailing slash") + + def rule_spec_from_path(self, path): + if len(path) < 2: + raise UnrecognizedRequestError() + if path[0] != 'pushrules': + raise UnrecognizedRequestError() + + scope = path[1] + path = path[2:] + if scope not in ['global', 'device']: + raise UnrecognizedRequestError() + + device = None + if scope == 'device': + if len(path) == 0: + raise UnrecognizedRequestError() + device = path[0] + path = path[1:] + + if len(path) == 0: + raise UnrecognizedRequestError() + + template = path[0] + path = path[1:] + + if len(path) == 0: + raise UnrecognizedRequestError() + + rule_id = path[0] + + spec = { + 'scope': scope, + 'template': template, + 'rule_id': rule_id + } + if device: + spec['device'] = device + return spec + + def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None): + if rule_template in ['override', 'underride']: + if 'conditions' not in req_obj: + raise InvalidRuleException("Missing 'conditions'") + conditions = req_obj['conditions'] + for c in conditions: + if 'kind' not in c: + raise InvalidRuleException("Condition without 'kind'") + elif rule_template == 'room': + conditions = [{ + 'kind': 'event_match', + 'key': 'room_id', + 'pattern': rule_id + }] + elif rule_template == 'sender': + conditions = [{ + 'kind': 'event_match', + 'key': 'user_id', + 'pattern': rule_id + }] + elif rule_template == 'content': + if 'pattern' not in req_obj: + raise InvalidRuleException("Content rule missing 'pattern'") + pat = req_obj['pattern'] + if pat.strip("*?[]") == pat: + # no special glob characters so we assume the user means + # 'contains this string' rather than 'is this string' + pat = "*%s*" % (pat,) + conditions = [{ + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': pat + }] + else: + raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) + + if device: + conditions.append({ + 'kind': 'device', + 'instance_handle': device + }) + + if 'actions' not in req_obj: + raise InvalidRuleException("No actions found") + actions = req_obj['actions'] + + for a in actions: + if a in ['notify', 'dont_notify', 'coalesce']: + pass + elif isinstance(a, dict) and 'set_sound' in a: + pass + else: + raise InvalidRuleException("Unrecognised action") + + return conditions, actions + + @defer.inlineCallbacks + def on_PUT(self, request): + spec = self.rule_spec_from_path(request.postpath) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + user, _ = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + try: + (conditions, actions) = self.rule_tuple_from_request_object( + spec['template'], + spec['rule_id'], + content, + device=spec['device'] if 'device' in spec else None + ) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + before = request.args.get("before", None) + if before and len(before): + before = before[0] + after = request.args.get("after", None) + if after and len(after): + after = after[0] + + try: + yield self.hs.get_datastore().add_push_rule( + user_name=user.to_string(), + rule_id=spec['rule_id'], + priority_class=priority_class, + conditions=conditions, + actions=actions, + before=before, + after=after + ) + except InconsistentRuleException as e: + raise SynapseError(400, e.message) + except RuleNotFoundException as e: + raise SynapseError(400, e.message) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_DELETE(self, request): + spec = self.rule_spec_from_path(request.postpath) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + user, _ = yield self.auth.get_user_by_req(request) + + if 'device' in spec: + rules = yield self.hs.get_datastore().get_push_rules_for_user_name( + user.to_string() + ) + + for r in rules: + conditions = json.loads(r['conditions']) + ih = _instance_handle_from_conditions(conditions) + if ih == spec['device'] and r['priority_class'] == priority_class: + yield self.hs.get_datastore().delete_push_rule( + user.to_string(), spec['rule_id'] + ) + defer.returnValue((200, {})) + raise NotFoundError() + else: + try: + yield self.hs.get_datastore().delete_push_rule( + user.to_string(), spec['rule_id'], + priority_class=priority_class + ) + defer.returnValue((200, {})) + except StoreError as e: + if e.code == 404: + raise NotFoundError() + else: + raise + + @defer.inlineCallbacks + def on_GET(self, request): + user, _ = yield self.auth.get_user_by_req(request) + + # we build up the full structure and then decide which bits of it + # to send which means doing unnecessary work sometimes but is + # is probably not going to make a whole lot of difference + rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string()) + + rules = {'global': {}, 'device': {}} + + rules['global'] = _add_empty_priority_class_arrays(rules['global']) + + for r in rawrules: + rulearray = None + + r["conditions"] = json.loads(r["conditions"]) + r["actions"] = json.loads(r["actions"]) + + template_name = _priority_class_to_template_name(r['priority_class']) + + if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: + # per-device rule + instance_handle = _instance_handle_from_conditions(r["conditions"]) + r = _strip_device_condition(r) + if not instance_handle: + continue + if instance_handle not in rules['device']: + rules['device'][instance_handle] = {} + rules['device'][instance_handle] = ( + _add_empty_priority_class_arrays( + rules['device'][instance_handle] + ) + ) + + rulearray = rules['device'][instance_handle][template_name] + else: + rulearray = rules['global'][template_name] + + template_rule = _rule_to_template(r) + if template_rule: + rulearray.append(template_rule) + + path = request.postpath[1:] + + if path == []: + # we're a reference impl: pedantry is our job. + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == '': + defer.returnValue((200, rules)) + elif path[0] == 'global': + path = path[1:] + result = _filter_ruleset_with_path(rules['global'], path) + defer.returnValue((200, result)) + elif path[0] == 'device': + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == '': + defer.returnValue((200, rules['device'])) + + instance_handle = path[0] + path = path[1:] + if instance_handle not in rules['device']: + ret = {} + ret = _add_empty_priority_class_arrays(ret) + defer.returnValue((200, ret)) + ruleset = rules['device'][instance_handle] + result = _filter_ruleset_with_path(ruleset, path) + defer.returnValue((200, result)) + else: + raise UnrecognizedRequestError() + + def on_OPTIONS(self, _): + return 200, {} + + +def _add_empty_priority_class_arrays(d): + for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys(): + d[pc] = [] + return d + + +def _instance_handle_from_conditions(conditions): + """ + Given a list of conditions, return the instance handle of the + device rule if there is one + """ + for c in conditions: + if c['kind'] == 'device': + return c['instance_handle'] + return None + + +def _filter_ruleset_with_path(ruleset, path): + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == '': + return ruleset + template_kind = path[0] + if template_kind not in ruleset: + raise UnrecognizedRequestError() + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == '': + return ruleset[template_kind] + rule_id = path[0] + for r in ruleset[template_kind]: + if r['rule_id'] == rule_id: + return r + raise NotFoundError + + +def _priority_class_from_spec(spec): + if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']] + + if spec['scope'] == 'device': + pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + + return pc + + +def _priority_class_to_template_name(pc): + if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: + # per-device + prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index] + else: + return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc] + + +def _rule_to_template(rule): + template_name = _priority_class_to_template_name(rule['priority_class']) + if template_name in ['override', 'underride']: + return {k: rule[k] for k in ["rule_id", "conditions", "actions"]} + elif template_name in ["sender", "room"]: + return {k: rule[k] for k in ["rule_id", "actions"]} + elif template_name == 'content': + if len(rule["conditions"]) != 1: + return None + thecond = rule["conditions"][0] + if "pattern" not in thecond: + return None + ret = {k: rule[k] for k in ["rule_id", "actions"]} + ret["pattern"] = thecond["pattern"] + return ret + + +def _strip_device_condition(rule): + for i, c in enumerate(rule['conditions']): + if c['kind'] == 'device': + del rule['conditions'][i] + return rule + + +class InvalidRuleException(Exception): + pass + + +# XXX: C+ped from rest/room.py - surely this should be common? +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py new file mode 100644 index 0000000000..353a4a6589 --- /dev/null +++ b/synapse/rest/client/v1/pusher.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, Codes +from synapse.push import PusherConfigException +from .base import ClientV1RestServlet, client_path_pattern + +import json + + +class PusherRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/pushers/set$") + + @defer.inlineCallbacks + def on_POST(self, request): + user, _ = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + pusher_pool = self.hs.get_pusherpool() + + if ('pushkey' in content and 'app_id' in content + and 'kind' in content and + content['kind'] is None): + yield pusher_pool.remove_pusher( + content['app_id'], content['pushkey'] + ) + defer.returnValue((200, {})) + + reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name', + 'device_display_name', 'pushkey', 'lang', 'data'] + missing = [] + for i in reqd: + if i not in content: + missing.append(i) + if len(missing): + raise SynapseError(400, "Missing parameters: "+','.join(missing), + errcode=Codes.MISSING_PARAM) + + try: + yield pusher_pool.add_pusher( + user_name=user.to_string(), + instance_handle=content['instance_handle'], + kind=content['kind'], + app_id=content['app_id'], + app_display_name=content['app_display_name'], + device_display_name=content['device_display_name'], + pushkey=content['pushkey'], + lang=content['lang'], + data=content['data'] + ) + except PusherConfigException as pce: + raise SynapseError(400, "Config Error: "+pce.message, + errcode=Codes.MISSING_PARAM) + + defer.returnValue((200, {})) + + def on_OPTIONS(self, _): + return 200, {} + + +# XXX: C+ped from rest/room.py - surely this should be common? +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + PusherRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f06e3ddb98..410f19ccf6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,6 +21,7 @@ from synapse.api.errors import SynapseError, Codes from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership from synapse.types import UserID, RoomID, RoomAlias +from synapse.events.utils import serialize_event import json import logging @@ -61,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -124,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -141,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet): defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key): - user = yield self.auth.get_user_by_req(request) + def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -157,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event(event_dict) + yield msg_handler.create_and_send_event( + event_dict, client=client, txn_id=txn_id, + ) defer.returnValue((200, {})) @@ -171,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server, with_get=True) @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, event_type, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -182,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "content": content, "room_id": room_id, "sender": user.to_string(), - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"event_id": event.event_id})) @@ -199,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, event_type) + response = yield self.on_POST(request, room_id, event_type, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -214,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_identifier): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_identifier, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -244,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet): "room_id": identifier.to_string(), "sender": user.to_string(), "state_key": user.to_string(), - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"room_id": identifier.to_string()})) @@ -258,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_identifier) + response = yield self.on_POST(request, room_identifier, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -282,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.room_member_handler members = yield handler.get_room_members_as_pagination_chunk( room_id=room_id, @@ -310,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -334,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -350,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -363,6 +370,10 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomTriggerBackfill(ClientV1RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$") + def __init__(self, hs): + super(RoomTriggerBackfill, self).__init__(hs) + self.clock = hs.get_clock() + @defer.inlineCallbacks def on_GET(self, request, room_id): remote_server = urllib.unquote( @@ -374,7 +385,9 @@ class RoomTriggerBackfill(ClientV1RestServlet): handler = self.handlers.federation_handler events = yield handler.backfill(remote_server, room_id, limit) - res = [self.hs.serialize_event(event) for event in events] + time_now = self.clock.time_msec() + + res = [serialize_event(event, time_now) for event in events] defer.returnValue((200, res)) @@ -388,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, membership_action, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -411,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), "state_key": state_key, - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {})) @@ -425,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, membership_action) + response = yield self.on_POST( + request, room_id, membership_action, txn_id + ) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -437,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, event_id, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -449,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), "redacts": event_id, - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"event_id": event.event_id})) @@ -463,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, event_id) + response = yield self.on_POST(request, room_id, event_id, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -476,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 822d863ce6..11d08fbced 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py new file mode 100644 index 0000000000..bb740e2803 --- /dev/null +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.http.server import JsonResource + + +class ClientV2AlphaRestResource(JsonResource): + """A resource for version 2 alpha of the matrix client API.""" + + def __init__(self, hs): + JsonResource.__init__(self) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(client_resource, hs): + pass diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py new file mode 100644 index 0000000000..22dc5cb862 --- /dev/null +++ b/synapse/rest/client/v2_alpha/_base.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX +import re + +import logging + + +logger = logging.getLogger(__name__) + + +def client_v2_pattern(path_regex): + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + SRE_Pattern + """ + return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 79ae0e3d74..22e26e3cd5 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index b1718a630b..b939a30e19 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource): @defer.inlineCallbacks def _async_render_POST(self, request): try: - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/server.py b/synapse/server.py index 32013b1a91..f152f0321e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,7 +20,6 @@ # Imports required for the default HomeServer() implementation from synapse.federation import initialize_http_replication -from synapse.events.utils import serialize_event from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers @@ -32,6 +31,7 @@ from synapse.util.lockutils import LockManager from synapse.streams.events import EventSources from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.keyring import Keyring +from synapse.push.pusherpool import PusherPool from synapse.events.builder import EventBuilderFactory @@ -70,6 +70,7 @@ class BaseHomeServer(object): 'notifier', 'distributor', 'resource_for_client', + 'resource_for_client_v2_alpha', 'resource_for_federation', 'resource_for_web_client', 'resource_for_content_repo', @@ -78,6 +79,7 @@ class BaseHomeServer(object): 'event_sources', 'ratelimiter', 'keyring', + 'pusherpool', 'event_builder_factory', ] @@ -123,9 +125,6 @@ class BaseHomeServer(object): setattr(BaseHomeServer, "get_%s" % (depname), _get) - def serialize_event(self, e, as_client_event=True): - return serialize_event(self, e, as_client_event) - def get_ip_from_request(self, request): # May be an X-Forwarding-For header depending on config ip_addr = request.getClientIP() @@ -200,3 +199,6 @@ class HomeServer(BaseHomeServer): clock=self.get_clock(), hostname=self.hostname, ) + + def build_pusherpool(self): + return PusherPool(self) diff --git a/synapse/state.py b/synapse/state.py index d9fdfb34be..dff11711a6 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -37,13 +37,15 @@ def _get_state_key_from_event(event): KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) +AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + + class StateHandler(object): """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() - # self.auth = hs.get_auth() self.hs = hs @defer.inlineCallbacks @@ -215,7 +217,7 @@ class StateHandler(object): auth_events = { k: e for k, e in unconflicted_state.items() - if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + if k[0] in AuthEventTypes } try: @@ -240,10 +242,6 @@ class StateHandler(object): 1. power levels 2. memberships 3. other events. - - :param conflicted_state: - :param auth_events: - :return: """ resolved_state = {} power_key = (EventTypes.PowerLevels, "") @@ -305,4 +303,4 @@ class StateHandler(object): def key_func(e): return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() - return sorted(events, key=key_func) \ No newline at end of file + return sorted(events, key=key_func) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 27d835db79..adcb038020 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -29,6 +29,8 @@ from .stream import StreamStore from .transactions import TransactionStore from .keys import KeyStore from .event_federation import EventFederationStore +from .pusher import PusherStore +from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore @@ -61,6 +63,7 @@ SCHEMAS = [ "state", "event_edges", "event_signatures", + "pusher", "media_repository", ] @@ -84,6 +87,8 @@ class DataStore(RoomMemberStore, RoomStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, + PusherStore, + PushRuleStore ): def __init__(self, hs): @@ -387,6 +392,41 @@ class DataStore(RoomMemberStore, RoomStore, defer.returnValue(events) @defer.inlineCallbacks + def get_room_name_and_aliases(self, room_id): + del_sql = ( + "SELECT event_id FROM redactions WHERE redacts = e.event_id " + "LIMIT 1" + ) + + sql = ( + "SELECT e.*, (%(redacted)s) AS redacted FROM events as e " + "INNER JOIN current_state_events as c ON e.event_id = c.event_id " + "INNER JOIN state_events as s ON e.event_id = s.event_id " + "WHERE c.room_id = ? " + ) % { + "redacted": del_sql, + } + + sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')" + sql += " OR s.type = 'm.room.aliases')" + args = (room_id,) + + results = yield self._execute_and_decode(sql, *args) + + events = yield self._parse_events(results) + + name = None + aliases = [] + + for e in events: + if e.type == 'm.room.name': + name = e.content['name'] + elif e.type == 'm.room.aliases': + aliases.extend(e.content['aliases']) + + defer.returnValue((name, aliases)) + + @defer.inlineCallbacks def _get_min_token(self): row = yield self._execute( None, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2075a018b2..1f5e74a16a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -193,6 +193,50 @@ class SQLBaseStore(object): txn.execute(sql, values.values()) return txn.lastrowid + def _simple_upsert(self, table, keyvalues, values): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + Returns: A deferred + """ + return self.runInteraction( + "_simple_upsert", + self._simple_upsert_txn, table, keyvalues, values + ) + + def _simple_upsert_txn(self, txn, table, keyvalues, values): + # Try to update + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join("%s = ?" % (k,) for k in keyvalues) + ) + sqlargs = values.values() + keyvalues.values() + logger.debug( + "[SQL] %s Args=%s", + sql, sqlargs, + ) + + txn.execute(sql, sqlargs) + if txn.rowcount == 0: + # We didn't update and rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues) + ) + logger.debug( + "[SQL] %s Args=%s", + sql, keyvalues.values(), + ) + txn.execute(sql, allvalues.values()) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -344,8 +388,8 @@ class SQLBaseStore(object): if updatevalues: update_sql = "UPDATE %s SET %s WHERE %s" % ( table, - ", ".join("%s = ?" % (k) for k in updatevalues), - " AND ".join("%s = ?" % (k) for k in keyvalues) + ", ".join("%s = ?" % (k,) for k in updatevalues), + " AND ".join("%s = ?" % (k,) for k in keyvalues) ) def func(txn): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py new file mode 100644 index 0000000000..27502d2399 --- /dev/null +++ b/synapse/storage/push_rule.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from ._base import SQLBaseStore, Table +from twisted.internet import defer + +import logging +import copy +import json + +logger = logging.getLogger(__name__) + + +class PushRuleStore(SQLBaseStore): + @defer.inlineCallbacks + def get_push_rules_for_user_name(self, user_name): + sql = ( + "SELECT "+",".join(PushRuleTable.fields)+" " + "FROM "+PushRuleTable.table_name+" " + "WHERE user_name = ? " + "ORDER BY priority_class DESC, priority DESC" + ) + rows = yield self._execute(None, sql, user_name) + + dicts = [] + for r in rows: + d = {} + for i, f in enumerate(PushRuleTable.fields): + d[f] = r[i] + dicts.append(d) + + defer.returnValue(dicts) + + @defer.inlineCallbacks + def add_push_rule(self, before, after, **kwargs): + vals = copy.copy(kwargs) + if 'conditions' in vals: + vals['conditions'] = json.dumps(vals['conditions']) + if 'actions' in vals: + vals['actions'] = json.dumps(vals['actions']) + # we could check the rest of the keys are valid column names + # but sqlite will do that anyway so I think it's just pointless. + if 'id' in vals: + del vals['id'] + + if before or after: + ret = yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + before=before, + after=after, + **vals + ) + defer.returnValue(ret) + else: + ret = yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + **vals + ) + defer.returnValue(ret) + + def _add_push_rule_relative_txn(self, txn, user_name, **kwargs): + after = None + relative_to_rule = None + if 'after' in kwargs and kwargs['after']: + after = kwargs['after'] + relative_to_rule = after + if 'before' in kwargs and kwargs['before']: + relative_to_rule = kwargs['before'] + + # get the priority of the rule we're inserting after/before + sql = ( + "SELECT priority_class, priority FROM ? " + "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,) + ) + txn.execute(sql, (user_name, relative_to_rule)) + res = txn.fetchall() + if not res: + raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule)) + priority_class, base_rule_priority = res[0] + + if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + new_rule = copy.copy(kwargs) + if 'before' in new_rule: + del new_rule['before'] + if 'after' in new_rule: + del new_rule['after'] + new_rule['priority_class'] = priority_class + new_rule['user_name'] = user_name + + # check if the priority before/after is free + new_rule_priority = base_rule_priority + if after: + new_rule_priority -= 1 + else: + new_rule_priority += 1 + + new_rule['priority'] = new_rule_priority + + sql = ( + "SELECT COUNT(*) FROM " + PushRuleTable.table_name + + " WHERE user_name = ? AND priority_class = ? AND priority = ?" + ) + txn.execute(sql, (user_name, priority_class, new_rule_priority)) + res = txn.fetchall() + num_conflicting = res[0][0] + + # if there are conflicting rules, bump everything + if num_conflicting: + sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority " + if after: + sql += "-1" + else: + sql += "+1" + sql += " WHERE user_name = ? AND priority_class = ? AND priority " + if after: + sql += "<= ?" + else: + sql += ">= ?" + + txn.execute(sql, (user_name, priority_class, new_rule_priority)) + + # now insert the new rule + sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql += ",".join(new_rule.keys())+") VALUES (" + sql += ", ".join(["?" for _ in new_rule.keys()])+")" + + txn.execute(sql, new_rule.values()) + + def _add_push_rule_highest_priority_txn(self, txn, user_name, + priority_class, **kwargs): + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name + + " WHERE user_name = ? and priority_class = ?" + ) + txn.execute(sql, (user_name, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + # and insert the new rule + new_rule = copy.copy(kwargs) + if 'id' in new_rule: + del new_rule['id'] + new_rule['user_name'] = user_name + new_rule['priority_class'] = priority_class + new_rule['priority'] = new_prio + + sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql += ",".join(new_rule.keys())+") VALUES (" + sql += ", ".join(["?" for _ in new_rule.keys()])+")" + + txn.execute(sql, new_rule.values()) + + @defer.inlineCallbacks + def delete_push_rule(self, user_name, rule_id, **kwargs): + """ + Delete a push rule. Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_name (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + yield self._simple_delete_one(PushRuleTable.table_name, kwargs) + + +class RuleNotFoundException(Exception): + pass + + +class InconsistentRuleException(Exception): + pass + + +class PushRuleTable(Table): + table_name = "push_rules" + + fields = [ + "id", + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ] + + EntryType = collections.namedtuple("PushRuleEntry", fields) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py new file mode 100644 index 0000000000..f253c9e2c3 --- /dev/null +++ b/synapse/storage/pusher.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from ._base import SQLBaseStore, Table +from twisted.internet import defer + +from synapse.api.errors import StoreError + +import logging + +logger = logging.getLogger(__name__) + + +class PusherStore(SQLBaseStore): + @defer.inlineCallbacks + def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): + sql = ( + "SELECT id, user_name, kind, instance_handle, app_id," + "app_display_name, device_display_name, pushkey, ts, data, " + "last_token, last_success, failing_since " + "FROM pushers " + "WHERE app_id = ? AND pushkey = ?" + ) + + rows = yield self._execute( + None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1] + ) + + ret = [ + { + "id": r[0], + "user_name": r[1], + "kind": r[2], + "instance_handle": r[3], + "app_id": r[4], + "app_display_name": r[5], + "device_display_name": r[6], + "pushkey": r[7], + "pushkey_ts": r[8], + "data": r[9], + "last_token": r[10], + "last_success": r[11], + "failing_since": r[12] + } + for r in rows + ] + + defer.returnValue(ret[0]) + + @defer.inlineCallbacks + def get_all_pushers(self): + sql = ( + "SELECT id, user_name, kind, instance_handle, app_id," + "app_display_name, device_display_name, pushkey, ts, data, " + "last_token, last_success, failing_since " + "FROM pushers" + ) + + rows = yield self._execute(None, sql) + + ret = [ + { + "id": r[0], + "user_name": r[1], + "kind": r[2], + "instance_handle": r[3], + "app_id": r[4], + "app_display_name": r[5], + "device_display_name": r[6], + "pushkey": r[7], + "pushkey_ts": r[8], + "data": r[9], + "last_token": r[10], + "last_success": r[11], + "failing_since": r[12] + } + for r in rows + ] + + defer.returnValue(ret) + + @defer.inlineCallbacks + def add_pusher(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, pushkey_ts, lang, data): + try: + yield self._simple_upsert( + PushersTable.table_name, + dict( + app_id=app_id, + pushkey=pushkey, + ), + dict( + user_name=user_name, + kind=kind, + instance_handle=instance_handle, + app_display_name=app_display_name, + device_display_name=device_display_name, + ts=pushkey_ts, + lang=lang, + data=data + )) + except Exception as e: + logger.error("create_pusher with failed: %s", e) + raise StoreError(500, "Problem creating pusher.") + + @defer.inlineCallbacks + def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): + yield self._simple_delete_one( + PushersTable.table_name, + dict(app_id=app_id, pushkey=pushkey) + ) + + @defer.inlineCallbacks + def update_pusher_last_token(self, user_name, pushkey, last_token): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'last_token': last_token} + ) + + @defer.inlineCallbacks + def update_pusher_last_token_and_success(self, user_name, pushkey, + last_token, last_success): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'last_token': last_token, 'last_success': last_success} + ) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, user_name, pushkey, failing_since): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'failing_since': failing_since} + ) + + +class PushersTable(Table): + table_name = "pushers" + + fields = [ + "id", + "user_name", + "kind", + "instance_handle", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "pushkey_ts", + "data", + "last_token", + "last_success", + "failing_since" + ] + + EntryType = collections.namedtuple("PusherEntry", fields) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 75dffa4db2..029b07cc66 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.device_id" + "SELECT users.name, users.admin," + " access_tokens.device_id, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.id = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql index bd2a8b1bb5..a6867cba62 100644 --- a/synapse/storage/schema/delta/v12.sql +++ b/synapse/storage/schema/delta/v12.sql @@ -19,3 +19,36 @@ CREATE TABLE IF NOT EXISTS rejections( last_check TEXT NOT NULL, CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE ); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + instance_handle varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/schema/pusher.sql b/synapse/storage/schema/pusher.sql new file mode 100644 index 0000000000..8c4dfd5c1b --- /dev/null +++ b/synapse/storage/schema/pusher.sql @@ -0,0 +1,46 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + instance_handle varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 8ac2adab05..062ca06fb3 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): def parse(cls, string): try: if string[0] == 's': - return cls(None, int(string[1:])) + return cls(topological=None, stream=int(string[1:])) if string[0] == 't': parts = string[1:].split('-', 1) - return cls(int(parts[1]), int(parts[0])) + return cls(topological=int(parts[0]), stream=int(parts[1])) except: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): def parse_stream_token(cls, string): try: if string[0] == 's': - return cls(None, int(string[1:])) + return cls(topological=None, stream=int(string[1:])) except: pass raise SynapseError(400, "Invalid token %r" % (string,)) diff --git a/synapse/types.py b/synapse/types.py index faac729ff2..f6a1b0bbcf 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -119,3 +119,6 @@ class StreamToken( d = self._asdict() d[key] = new_value return StreamToken(**d) + + +ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 65d5cc4916..f849120a3e 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase): "user": UserID.from_string(myid), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase): "user": UserID.from_string(myid), "admin": False, "device_id": None, + "token_id": 1, } hs.handlers.room_member_handler = Mock( @@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.get_clock().time_msec.return_value = 1000000 def _get_user_by_req(req=None): - return UserID.from_string(myid) + return (UserID.from_string(myid), "") hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 39cd68d829..6a2085276a 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None): - return UserID.from_string(myid) + return (UserID.from_string(myid), "") hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 76ed550b75..81ead10e76 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index c89b37d004..c5d5b06da3 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py new file mode 100644 index 0000000000..f59745e13c --- /dev/null +++ b/tests/rest/client/v2_alpha/__init__.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import unittest + +from mock import Mock + +from ....utils import MockHttpResource, MockKey + +from synapse.server import HomeServer +from synapse.types import UserID + + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class V2AlphaRestTestCase(unittest.TestCase): + # Consumer must define + # USER_ID = <some string> + # TO_REGISTER = [<list of REST servlets to register>] + + def setUp(self): + self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) + + mock_config = Mock() + mock_config.signing_key = [MockKey()] + + hs = HomeServer("test", + db_pool=None, + datastore=Mock(spec=[ + "insert_client_ip", + ]), + http_client=None, + resource_for_client=self.mock_resource, + resource_for_federation=self.mock_resource, + config=mock_config, + ) + + def _get_user_by_token(token=None): + return { + "user": UserID.from_string(self.USER_ID), + "admin": False, + "device_id": None, + } + hs.get_auth().get_user_by_token = _get_user_by_token + + for r in self.TO_REGISTER: + r.register_servlets(hs, self.mock_resource) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 84bfde7568..6f8bea2f61 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase): ) self.assertEquals( - {"admin": 0, "device_id": None, "name": self.user_id}, + {"admin": 0, + "device_id": None, + "name": self.user_id, + "token_id": 1}, (yield self.store.get_user_by_token(self.tokens[0])) ) @@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) self.assertEquals( - {"admin": 0, "device_id": None, "name": self.user_id}, + {"admin": 0, + "device_id": None, + "name": self.user_id, + "token_id": 2}, (yield self.store.get_user_by_token(self.tokens[1])) ) |