diff --git a/synapse/__init__.py b/synapse/__init__.py
index f68a15bb85..7ff37edf2c 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.11.0-rc1"
+__version__ = "0.11.0-r2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3e891a6193..4fdc779b4b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -207,6 +207,13 @@ class Auth(object):
user_id, room_id
))
+ if membership == Membership.LEAVE:
+ forgot = yield self.store.did_forget(user_id, room_id)
+ if forgot:
+ raise AuthError(403, "User %s not in room %s" % (
+ user_id, room_id
+ ))
+
defer.returnValue(member)
@defer.inlineCallbacks
@@ -587,7 +594,7 @@ class Auth(object):
def _get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
- self._validate_macaroon(macaroon)
+ self.validate_macaroon(macaroon, "access", False)
user_prefix = "user_id = "
user = None
@@ -635,13 +642,27 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
- def _validate_macaroon(self, macaroon):
+ def validate_macaroon(self, macaroon, type_string, verify_expiry):
+ """
+ validate that a Macaroon is understood by and was signed by this server.
+
+ Args:
+ macaroon(pymacaroons.Macaroon): The macaroon to validate
+ type_string(str): The kind of token this is (e.g. "access", "refresh")
+ verify_expiry(bool): Whether to verify whether the macaroon has expired.
+ This should really always be True, but no clients currently implement
+ token refresh, so we can't enforce expiry yet.
+ """
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
- v.satisfy_exact("type = access")
+ v.satisfy_exact("type = " + type_string)
v.satisfy_general(lambda c: c.startswith("user_id = "))
- v.satisfy_general(self._verify_expiry)
v.satisfy_exact("guest = true")
+ if verify_expiry:
+ v.satisfy_general(self._verify_expiry)
+ else:
+ v.satisfy_general(lambda c: c.startswith("time < "))
+
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
@@ -652,9 +673,6 @@ class Auth(object):
prefix = "time < "
if not caveat.startswith(prefix):
return False
- # TODO(daniel): Enable expiry check when clients actually know how to
- # refresh tokens. (And remember to enable the tests)
- return True
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index aaa2433cae..18f2ec3ae8 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -54,7 +54,7 @@ class Filtering(object):
]
room_level_definitions = [
- "state", "timeline", "ephemeral", "private_user_data"
+ "state", "timeline", "ephemeral", "account_data"
]
for key in top_level_definitions:
@@ -131,8 +131,8 @@ class FilterCollection(object):
self.filter_json.get("room", {}).get("ephemeral", {})
)
- self.room_private_user_data = Filter(
- self.filter_json.get("room", {}).get("private_user_data", {})
+ self.room_account_data = Filter(
+ self.filter_json.get("room", {}).get("account_data", {})
)
self.presence_filter = Filter(
@@ -160,8 +160,8 @@ class FilterCollection(object):
def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events)
- def filter_room_private_user_data(self, events):
- return self.room_private_user_data.filter(events)
+ def filter_room_account_data(self, events):
+ return self.room_account_data.filter(events)
class Filter(object):
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index c18e0bdbb8..d0c9972445 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -25,18 +25,29 @@ class ConfigError(Exception):
pass
-class Config(object):
+# We split these messages out to allow packages to override with package
+# specific instructions.
+MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
+Please opt in or out of reporting anonymized homeserver usage statistics, by
+setting the `report_stats` key in your config file to either True or False.
+"""
+
+MISSING_REPORT_STATS_SPIEL = """\
+We would really appreciate it if you could help our project out by reporting
+anonymized usage statistics from your homeserver. Only very basic aggregate
+data (e.g. number of users) will be reported, but it helps us to track the
+growth of the Matrix community, and helps us to make Matrix a success, as well
+as to convince other networks that they should peer with us.
+
+Thank you.
+"""
+
+MISSING_SERVER_NAME = """\
+Missing mandatory `server_name` config option.
+"""
- stats_reporting_begging_spiel = (
- "We would really appreciate it if you could help our project out by"
- " reporting anonymized usage statistics from your homeserver. Only very"
- " basic aggregate data (e.g. number of users) will be reported, but it"
- " helps us to track the growth of the Matrix community, and helps us to"
- " make Matrix a success, as well as to convince other networks that they"
- " should peer with us."
- "\nThank you."
- )
+class Config(object):
@staticmethod
def parse_size(value):
if isinstance(value, int) or isinstance(value, long):
@@ -215,7 +226,7 @@ class Config(object):
if config_args.report_stats is None:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
- cls.stats_reporting_begging_spiel
+ MISSING_REPORT_STATS_SPIEL
)
if not config_files:
config_parser.error(
@@ -290,6 +301,10 @@ class Config(object):
yaml_config = cls.read_config_file(config_file)
specified_config.update(yaml_config)
+ if "server_name" not in specified_config:
+ sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
+ sys.exit(1)
+
server_name = specified_config["server_name"]
_, config = obj.generate_config(
config_dir_path=config_dir_path,
@@ -299,11 +314,8 @@ class Config(object):
config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
- "Please opt in or out of reporting anonymized homeserver usage "
- "statistics, by setting the report_stats key in your config file "
- " ( " + config_path + " ) " +
- "to either True or False.\n\n" +
- Config.stats_reporting_begging_spiel + "\n")
+ "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
+ MISSING_REPORT_STATS_SPIEL + "\n")
sys.exit(1)
if generate_keys:
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index a337ae6ca0..326e405841 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -27,10 +27,12 @@ class CasConfig(Config):
if cas_config:
self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"]
+ self.cas_service_url = cas_config["service_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {})
else:
self.cas_enabled = False
self.cas_server_url = None
+ self.cas_service_url = None
self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs):
@@ -39,6 +41,7 @@ class CasConfig(Config):
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"
+ # service_url: "https://homesever.domain.com:8448"
# #required_attributes:
# # name: value
"""
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8b6a59866f..bc5bb5cdb1 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -381,28 +381,24 @@ class Keyring(object):
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
- limiter = yield get_retry_limiter(
- perspective_name, self.clock, self.store
- )
-
- with limiter:
- # TODO(mark): Set the minimum_valid_until_ts to that needed by
- # the events being validated or the current time if validating
- # an incoming request.
- query_response = yield self.client.post_json(
- destination=perspective_name,
- path=b"/_matrix/key/v2/query",
- data={
- u"server_keys": {
- server_name: {
- key_id: {
- u"minimum_valid_until_ts": 0
- } for key_id in key_ids
- }
- for server_name, key_ids in server_names_and_key_ids
+ # TODO(mark): Set the minimum_valid_until_ts to that needed by
+ # the events being validated or the current time if validating
+ # an incoming request.
+ query_response = yield self.client.post_json(
+ destination=perspective_name,
+ path=b"/_matrix/key/v2/query",
+ data={
+ u"server_keys": {
+ server_name: {
+ key_id: {
+ u"minimum_valid_until_ts": 0
+ } for key_id in key_ids
}
- },
- )
+ for server_name, key_ids in server_names_and_key_ids
+ }
+ },
+ long_retries=True,
+ )
keys = {}
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9989b76591..44cc1ef132 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -129,10 +129,9 @@ def format_event_for_client_v2(d):
return d
-def format_event_for_client_v2_without_event_id(d):
+def format_event_for_client_v2_without_room_id(d):
d = format_event_for_client_v2(d)
d.pop("room_id", None)
- d.pop("event_id", None)
return d
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d4f586fae7..c6a8c1249a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -401,6 +401,12 @@ class FederationClient(FederationBase):
pdu_dict["content"].update(content)
+ # The protoevent received over the JSON wire may not have all
+ # the required fields. Lets just gloss over that because
+ # there's some we never care about
+ if "prev_state" not in pdu_dict:
+ pdu_dict["prev_state"] = []
+
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 3d59e1c650..0e0cb7ebc6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -136,6 +136,7 @@ class TransportLayerClient(object):
path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=json_data,
json_data_callback=json_data_callback,
+ long_retries=True,
)
logger.debug(
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6519f183df..5fd20285d2 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -92,7 +92,15 @@ class BaseHandler(object):
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
- membership = membership_event.membership
+ was_forgotten_at_event = yield self.store.was_forgotten_at(
+ membership_event.state_key,
+ membership_event.room_id,
+ membership_event.event_id
+ )
+ if was_forgotten_at_event:
+ membership = None
+ else:
+ membership = membership_event.membership
else:
membership = None
diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/account_data.py
index 1abe45ed7b..1d35d3b7dc 100644
--- a/synapse/handlers/private_user_data.py
+++ b/synapse/handlers/account_data.py
@@ -16,19 +16,19 @@
from twisted.internet import defer
-class PrivateUserDataEventSource(object):
+class AccountDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
- return self.store.get_max_private_user_data_stream_id()
+ return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_private_user_data_stream_id()
+ current_stream_id = yield self.store.get_max_account_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = []
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1b11dbdffd..e64b67cdfd 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
-from synapse.api.errors import LoginError, Codes
+from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
@@ -46,6 +46,7 @@ class AuthHandler(BaseHandler):
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {}
+ self.INVALID_TOKEN_HTTP_STATUS = 401
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -297,10 +298,11 @@ class AuthHandler(BaseHandler):
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
- def login_with_cas_user_id(self, user_id):
+ def get_login_tuple_for_user_id(self, user_id):
"""
- Authenticates the user with the given user ID,
- intended to have been captured from a CAS response
+ Gets login tuple for the user with the given user ID.
+ The user is assumed to have been authenticated by some other
+ machanism (e.g. CAS)
Args:
user_id (str): User ID
@@ -393,6 +395,23 @@ class AuthHandler(BaseHandler):
))
return m.serialize()
+ def generate_short_term_login_token(self, user_id):
+ macaroon = self._generate_base_macaroon(user_id)
+ macaroon.add_first_party_caveat("type = login")
+ now = self.hs.get_clock().time_msec()
+ expiry = now + (2 * 60 * 1000)
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ return macaroon.serialize()
+
+ def validate_short_term_login_token_and_get_user_id(self, login_token):
+ try:
+ macaroon = pymacaroons.Macaroon.deserialize(login_token)
+ auth_api = self.hs.get_auth()
+ auth_api.validate_macaroon(macaroon, "login", True)
+ return self._get_user_from_macaroon(macaroon)
+ except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
@@ -402,6 +421,16 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+ def _get_user_from_macaroon(self, macaroon):
+ user_prefix = "user_id = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(user_prefix):
+ return caveat.caveat_id[len(user_prefix):]
+ raise AuthError(
+ self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a92409c6a2..64c57375f7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -436,14 +436,14 @@ class MessageHandler(BaseHandler):
for c in current_state.values()
]
- private_user_data = []
+ account_data = []
tags = tags_by_room.get(event.room_id)
if tags:
- private_user_data.append({
+ account_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
- d["private_user_data"] = private_user_data
+ d["account_data"] = account_data
except:
logger.exception("Failed to get snapshot")
@@ -498,14 +498,14 @@ class MessageHandler(BaseHandler):
user_id, room_id, pagin_config, membership, member_event_id, is_guest
)
- private_user_data = []
+ account_data = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
- private_user_data.append({
+ account_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
- result["private_user_data"] = private_user_data
+ result["account_data"] = account_data
defer.returnValue(result)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3f04752581..023b4001b8 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -743,6 +743,9 @@ class RoomMemberHandler(BaseHandler):
)
defer.returnValue((token, public_key, key_validity_url, display_name))
+ def forget(self, user, room_id):
+ self.store.forget(user.to_string(), room_id)
+
class RoomListHandler(BaseHandler):
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index b7545c111f..50688e51a8 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -17,13 +17,14 @@ from twisted.internet import defer
from ._base import BaseHandler
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, EventTypes
from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from unpaddedbase64 import decode_base64, encode_base64
+import itertools
import logging
@@ -79,6 +80,9 @@ class SearchHandler(BaseHandler):
# What to order results by (impacts whether pagination can be doen)
order_by = room_cat.get("order_by", "rank")
+ # Return the current state of the rooms?
+ include_state = room_cat.get("include_state", False)
+
# Include context around each event?
event_context = room_cat.get(
"event_context", None
@@ -96,6 +100,10 @@ class SearchHandler(BaseHandler):
after_limit = int(event_context.get(
"after_limit", 5
))
+
+ # Return the historic display name and avatar for the senders
+ # of the events?
+ include_profile = bool(event_context.get("include_profile", False))
except KeyError:
raise SynapseError(400, "Invalid search query")
@@ -269,6 +277,33 @@ class SearchHandler(BaseHandler):
"room_key", res["end"]
).to_string()
+ if include_profile:
+ senders = set(
+ ev.sender
+ for ev in itertools.chain(
+ res["events_before"], [event], res["events_after"]
+ )
+ )
+
+ if res["events_after"]:
+ last_event_id = res["events_after"][-1].event_id
+ else:
+ last_event_id = event.event_id
+
+ state = yield self.store.get_state_for_event(
+ last_event_id,
+ types=[(EventTypes.Member, sender) for sender in senders]
+ )
+
+ res["profile_info"] = {
+ s.state_key: {
+ "displayname": s.content.get("displayname", None),
+ "avatar_url": s.content.get("avatar_url", None),
+ }
+ for s in state.values()
+ if s.type == EventTypes.Member and s.state_key in senders
+ }
+
contexts[event.event_id] = res
else:
contexts = {}
@@ -287,6 +322,18 @@ class SearchHandler(BaseHandler):
for e in context["events_after"]
]
+ state_results = {}
+ if include_state:
+ rooms = set(e.room_id for e in allowed_events)
+ for room_id in rooms:
+ state = yield self.state_handler.get_current_state(room_id)
+ state_results[room_id] = state.values()
+
+ state_results.values()
+
+ # We're now about to serialize the events. We should not make any
+ # blocking calls after this. Otherwise the 'age' will be wrong
+
results = {
e.event_id: {
"rank": rank_map[e.event_id],
@@ -303,6 +350,12 @@ class SearchHandler(BaseHandler):
"count": len(results)
}
+ if state_results:
+ rooms_cat_res["state"] = {
+ room_id: [serialize_event(e, time_now) for e in state]
+ for room_id, state in state_results.items()
+ }
+
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6dc9d0fb92..877328b29e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -51,7 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
"ephemeral",
- "private_user_data",
+ "account_data",
])):
__slots__ = []
@@ -63,7 +63,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
self.timeline
or self.state
or self.ephemeral
- or self.private_user_data
+ or self.account_data
)
@@ -71,7 +71,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", # str
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
- "private_user_data",
+ "account_data",
])):
__slots__ = []
@@ -82,7 +82,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
return bool(
self.timeline
or self.state
- or self.private_user_data
+ or self.account_data
)
@@ -261,20 +261,20 @@ class SyncHandler(BaseHandler):
timeline=batch,
state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []),
- private_user_data=self.private_user_data_for_room(
+ account_data=self.account_data_for_room(
room_id, tags_by_room
),
))
- def private_user_data_for_room(self, room_id, tags_by_room):
- private_user_data = []
+ def account_data_for_room(self, room_id, tags_by_room):
+ account_data = []
tags = tags_by_room.get(room_id)
if tags is not None:
- private_user_data.append({
+ account_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
- return private_user_data
+ return account_data
@defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None):
@@ -357,7 +357,7 @@ class SyncHandler(BaseHandler):
room_id=room_id,
timeline=batch,
state=leave_state,
- private_user_data=self.private_user_data_for_room(
+ account_data=self.account_data_for_room(
room_id, tags_by_room
),
))
@@ -412,7 +412,7 @@ class SyncHandler(BaseHandler):
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
- since_token.private_user_data_key,
+ since_token.account_data_key,
)
joined = []
@@ -468,7 +468,7 @@ class SyncHandler(BaseHandler):
),
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
- private_user_data=self.private_user_data_for_room(
+ account_data=self.account_data_for_room(
room_id, tags_by_room
),
)
@@ -605,7 +605,7 @@ class SyncHandler(BaseHandler):
timeline=batch,
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
- private_user_data=self.private_user_data_for_room(
+ account_data=self.account_data_for_room(
room_id, tags_by_room
),
)
@@ -653,7 +653,7 @@ class SyncHandler(BaseHandler):
room_id=leave_event.room_id,
timeline=batch,
state=state_events_delta,
- private_user_data=self.private_user_data_for_room(
+ account_data=self.account_data_for_room(
leave_event.room_id, tags_by_room
),
)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 6e53538a52..b7b7c2cce8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -56,7 +56,8 @@ incoming_responses_counter = metrics.register_counter(
)
-MAX_RETRIES = 4
+MAX_LONG_RETRIES = 10
+MAX_SHORT_RETRIES = 3
class MatrixFederationEndpointFactory(object):
@@ -103,7 +104,7 @@ class MatrixFederationHttpClient(object):
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
- timeout=None):
+ timeout=None, long_retries=False):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [self.version_string]
@@ -123,7 +124,10 @@ class MatrixFederationHttpClient(object):
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
- retries_left = MAX_RETRIES
+ if long_retries:
+ retries_left = MAX_LONG_RETRIES
+ else:
+ retries_left = MAX_SHORT_RETRIES
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
@@ -184,8 +188,15 @@ class MatrixFederationHttpClient(object):
)
if retries_left and not timeout:
- delay = 5 ** (MAX_RETRIES + 1 - retries_left)
- delay *= random.uniform(0.8, 1.4)
+ if long_retries:
+ delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+ delay = min(delay, 60)
+ delay *= random.uniform(0.8, 1.4)
+ else:
+ delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+ delay = min(delay, 2)
+ delay *= random.uniform(0.8, 1.4)
+
yield sleep(delay)
retries_left -= 1
else:
@@ -236,7 +247,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, data={}, json_data_callback=None):
+ def put_json(self, destination, path, data={}, json_data_callback=None,
+ long_retries=False):
""" Sends the specifed json data using PUT
Args:
@@ -247,6 +259,8 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
+ long_retries (bool): A boolean that indicates whether we should
+ retry for a short or long time.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -272,6 +286,7 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
+ long_retries=long_retries,
)
if 200 <= response.code < 300:
@@ -287,7 +302,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json(self, destination, path, data={}):
+ def post_json(self, destination, path, data={}, long_retries=True):
""" Sends the specifed json data using POST
Args:
@@ -296,6 +311,8 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
+ long_retries (bool): A boolean that indicates whether we should
+ retry for a short or long time.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
@@ -315,6 +332,7 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
+ long_retries=True,
)
if 200 <= response.code < 300:
@@ -490,6 +508,9 @@ class _JsonProducer(object):
def stopProducing(self):
pass
+ def resumeProducing(self):
+ pass
+
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4ea06c1434..720d6358e7 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -22,6 +22,7 @@ from base import ClientV1RestServlet, client_path_pattern
import simplejson as json
import urllib
+import urlparse
import logging
from saml2 import BINDING_HTTP_POST
@@ -39,6 +40,7 @@ class LoginRestServlet(ClientV1RestServlet):
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
+ TOKEN_TYPE = "m.login.token"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
@@ -56,8 +58,18 @@ class LoginRestServlet(ClientV1RestServlet):
flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE})
+
+ # While its valid for us to advertise this login type generally,
+ # synapse currently only gives out these tokens as part of the
+ # CAS login flow.
+ # Generally we don't want to advertise login flows that clients
+ # don't know how to implement, since they (currently) will always
+ # fall back to the fallback API if they don't understand one of the
+ # login flow types returned.
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE})
+
return (200, {"flows": flows})
def on_OPTIONS(self, request):
@@ -83,6 +95,7 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
+ # TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for
@@ -96,6 +109,9 @@ class LoginRestServlet(ClientV1RestServlet):
body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
+ elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+ result = yield self.do_token_login(login_submission)
+ defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
@@ -132,6 +148,26 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
@defer.inlineCallbacks
+ def do_token_login(self, login_submission):
+ token = login_submission['token']
+ auth_handler = self.handlers.auth_handler
+ user_id = (
+ yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ )
+ user_id, access_token, refresh_token = (
+ yield auth_handler.get_login_tuple_for_user_id(user_id)
+ )
+ result = {
+ "user_id": user_id, # may have changed
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ "home_server": self.hs.hostname,
+ }
+
+ defer.returnValue((200, result))
+
+ # TODO Delete this after all CAS clients switch to token login instead
+ @defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
@@ -152,7 +188,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
- yield auth_handler.login_with_cas_user_id(user_id)
+ yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
@@ -173,6 +209,7 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
+ # TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
@@ -243,6 +280,7 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
+# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas")
@@ -254,6 +292,115 @@ class CasRestServlet(ClientV1RestServlet):
return (200, {"serverUrl": self.cas_server_url})
+class CasRedirectServlet(ClientV1RestServlet):
+ PATTERN = client_path_pattern("/login/cas/redirect")
+
+ def __init__(self, hs):
+ super(CasRedirectServlet, self).__init__(hs)
+ self.cas_server_url = hs.config.cas_server_url
+ self.cas_service_url = hs.config.cas_service_url
+
+ def on_GET(self, request):
+ args = request.args
+ if "redirectUrl" not in args:
+ return (400, "Redirect URL not specified for CAS auth")
+ client_redirect_url_param = urllib.urlencode({
+ "redirectUrl": args["redirectUrl"][0]
+ })
+ hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
+ service_param = urllib.urlencode({
+ "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
+ })
+ request.redirect("%s?%s" % (self.cas_server_url, service_param))
+ request.finish()
+
+
+class CasTicketServlet(ClientV1RestServlet):
+ PATTERN = client_path_pattern("/login/cas/ticket")
+
+ def __init__(self, hs):
+ super(CasTicketServlet, self).__init__(hs)
+ self.cas_server_url = hs.config.cas_server_url
+ self.cas_service_url = hs.config.cas_service_url
+ self.cas_required_attributes = hs.config.cas_required_attributes
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ client_redirect_url = request.args["redirectUrl"][0]
+ http_client = self.hs.get_simple_http_client()
+ uri = self.cas_server_url + "/proxyValidate"
+ args = {
+ "ticket": request.args["ticket"],
+ "service": self.cas_service_url
+ }
+ body = yield http_client.get_raw(uri, args)
+ result = yield self.handle_cas_response(request, body, client_redirect_url)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def handle_cas_response(self, request, cas_response_body, client_redirect_url):
+ user, attributes = self.parse_cas_response(cas_response_body)
+
+ for required_attribute, required_value in self.cas_required_attributes.items():
+ # If required attribute was not in CAS Response - Forbidden
+ if required_attribute not in attributes:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = attributes[required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if required_value != actual_value:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ user_id = UserID.create(user, self.hs.hostname).to_string()
+ auth_handler = self.handlers.auth_handler
+ user_exists = yield auth_handler.does_user_exist(user_id)
+ if not user_exists:
+ user_id, _ = (
+ yield self.handlers.registration_handler.register(localpart=user)
+ )
+
+ login_token = auth_handler.generate_short_term_login_token(user_id)
+ redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
+ login_token)
+ request.redirect(redirect_url)
+ request.finish()
+
+ def add_login_token_to_redirect_url(self, url, token):
+ url_parts = list(urlparse.urlparse(url))
+ query = dict(urlparse.parse_qsl(url_parts[4]))
+ query.update({"loginToken": token})
+ url_parts[4] = urllib.urlencode(query)
+ return urlparse.urlunparse(url_parts)
+
+ def parse_cas_response(self, cas_response_body):
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+ if not root[0].tag.endswith("authenticationSuccess"):
+ raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ attributes = {}
+ for attribute in child:
+ # ElementTree library expands the namespace in attribute tags
+ # to the full URL of the namespace.
+ # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
+ # We don't care about namespace here and it will always be encased in
+ # curly braces, so we remove them.
+ if "}" in attribute.tag:
+ attributes[attribute.tag.split("}")[1]] = attribute.text
+ else:
+ attributes[attribute.tag] = attribute.text
+ if user is None or attributes is None:
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+
+ return (user, attributes)
+
+
def _parse_json(request):
try:
content = json.loads(request.content.read())
@@ -269,5 +416,7 @@ def register_servlets(hs, http_server):
if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
+ CasRedirectServlet(hs).register(http_server)
+ CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 139dac1cc3..6952d269ec 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -448,7 +448,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|kick)")
+ "(?P<membership_action>join|invite|leave|ban|kick|forget)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
@@ -458,6 +458,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
allow_guest=True
)
+ effective_membership_action = membership_action
+
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
raise AuthError(403, "Guest access not allowed")
@@ -488,11 +490,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
UserID.from_string(state_key)
if membership_action == "kick":
- membership_action = "leave"
+ effective_membership_action = "leave"
+ elif membership_action == "forget":
+ effective_membership_action = "leave"
msg_handler = self.handlers.message_handler
- content = {"membership": unicode(membership_action)}
+ content = {"membership": unicode(effective_membership_action)}
if is_guest:
content["kind"] = "guest"
@@ -509,6 +513,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
is_guest=is_guest,
)
+ if membership_action == "forget":
+ self.handlers.room_member_handler.forget(user, room_id)
+
defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index efd8281558..775f49885b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -22,7 +22,7 @@ from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import (
- serialize_event, format_event_for_client_v2_without_event_id,
+ serialize_event, format_event_for_client_v2_without_room_id,
)
from synapse.api.filtering import FilterCollection
from ._base import client_v2_pattern
@@ -148,9 +148,9 @@ class SyncRestServlet(RestServlet):
sync_result.presence, filter, time_now
),
"rooms": {
- "joined": joined,
- "invited": invited,
- "archived": archived,
+ "join": joined,
+ "invite": invited,
+ "leave": archived,
},
"next_batch": sync_result.next_batch.to_string(),
}
@@ -207,7 +207,7 @@ class SyncRestServlet(RestServlet):
for room in rooms:
invite = serialize_event(
room.invite, time_now, token_id=token_id,
- event_format=format_event_for_client_v2_without_event_id,
+ event_format=format_event_for_client_v2_without_room_id,
)
invited_state = invite.get("unsigned", {}).pop("invite_room_state", [])
invited_state.append(invite)
@@ -256,7 +256,13 @@ class SyncRestServlet(RestServlet):
:return: the room, encoded in our response format
:rtype: dict[str, object]
"""
- event_map = {}
+ def serialize(event):
+ # TODO(mjark): Respect formatting requirements in the filter.
+ return serialize_event(
+ event, time_now, token_id=token_id,
+ event_format=format_event_for_client_v2_without_room_id,
+ )
+
state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events)
@@ -264,37 +270,22 @@ class SyncRestServlet(RestServlet):
state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values())
- state_event_ids = []
- for event in state_events:
- # TODO(mjark): Respect formatting requirements in the filter.
- event_map[event.event_id] = serialize_event(
- event, time_now, token_id=token_id,
- event_format=format_event_for_client_v2_without_event_id,
- )
- state_event_ids.append(event.event_id)
- timeline_event_ids = []
- for event in timeline_events:
- # TODO(mjark): Respect formatting requirements in the filter.
- event_map[event.event_id] = serialize_event(
- event, time_now, token_id=token_id,
- event_format=format_event_for_client_v2_without_event_id,
- )
- timeline_event_ids.append(event.event_id)
+ serialized_state = [serialize(e) for e in state_events]
+ serialized_timeline = [serialize(e) for e in timeline_events]
- private_user_data = filter.filter_room_private_user_data(
- room.private_user_data
+ account_data = filter.filter_room_account_data(
+ room.account_data
)
result = {
- "event_map": event_map,
"timeline": {
- "events": timeline_event_ids,
+ "events": serialized_timeline,
"prev_batch": room.timeline.prev_batch.to_string(),
"limited": room.timeline.limited,
},
- "state": {"events": state_event_ids},
- "private_user_data": {"events": private_user_data},
+ "state": {"events": serialized_state},
+ "account_data": {"events": account_data},
}
if joined:
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 35482ae6a6..ba7223be11 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -81,7 +81,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event(
- "private_user_data_key", max_id, users=[user_id]
+ "account_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
@@ -95,7 +95,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
- "private_user_data_key", max_id, users=[user_id]
+ "account_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index ab8b4d44ea..bfb7386035 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -17,12 +17,11 @@ var submitPassword = function(user, pwd) {
}).error(errorFunc);
};
-var submitCas = function(ticket, service) {
- console.log("Logging in with cas...");
+var submitToken = function(loginToken) {
+ console.log("Logging in with login token...");
var data = {
- type: "m.login.cas",
- ticket: ticket,
- service: service,
+ type: "m.login.token",
+ token: loginToken
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
@@ -41,23 +40,10 @@ var errorFunc = function(err) {
}
};
-var getCasURL = function(cb) {
- $.get(matrixLogin.endpoint + "/cas", function(response) {
- var cas_url = response.serverUrl;
-
- cb(cas_url);
- }).error(errorFunc);
-};
-
-
var gotoCas = function() {
- getCasURL(function(cas_url) {
- var this_page = window.location.origin + window.location.pathname;
-
- var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page);
-
- window.location.replace(redirect_url);
- });
+ var this_page = window.location.origin + window.location.pathname;
+ var redirect_url = matrixLogin.endpoint + "/cas/redirect?redirectUrl=" + encodeURIComponent(this_page);
+ window.location.replace(redirect_url);
}
var setFeedbackString = function(text) {
@@ -111,7 +97,7 @@ var fetch_info = function(cb) {
matrixLogin.onLoad = function() {
fetch_info(function() {
- if (!try_cas()) {
+ if (!try_token()) {
show_login();
}
});
@@ -148,20 +134,20 @@ var parseQsFromUrl = function(query) {
return result;
};
-var try_cas = function() {
+var try_token = function() {
var pos = window.location.href.indexOf("?");
if (pos == -1) {
return false;
}
var qs = parseQsFromUrl(window.location.href.substr(pos+1));
- var ticket = qs.ticket;
+ var loginToken = qs.loginToken;
- if (!ticket) {
+ if (!loginToken) {
return false;
}
- submitCas(ticket, location.origin);
+ submitToken(loginToken);
return true;
};
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1a74d6e360..9800fd4203 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 25
+SCHEMA_VERSION = 26
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index ae1ad56d9a..d32ce1ab1e 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -160,7 +160,7 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
- where_clause = "user_id = ? AND (%s)" % (
+ where_clause = "user_id = ? AND (%s) AND NOT forgotten" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
@@ -269,3 +269,67 @@ class RoomMemberStore(SQLBaseStore):
ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
defer.returnValue(ret)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+ self.runInteraction("forget_membership", f)
+
+ @defer.inlineCallbacks
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ count = yield self.runInteraction("did_forget_membership", f)
+ defer.returnValue(count == 0)
+
+ @defer.inlineCallbacks
+ def was_forgotten_at(self, user_id, room_id, event_id):
+ """Returns whether user_id has elected to discard history for room_id at event_id.
+
+ event_id must be a membership event."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " forgotten"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " event_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id, event_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ forgot = yield self.runInteraction("did_forget_membership_at", f)
+ defer.returnValue(forgot == 1)
diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/schema/delta/26/account_data.sql
new file mode 100644
index 0000000000..3198a0d29c
--- /dev/null
+++ b/synapse/storage/schema/delta/26/account_data.sql
@@ -0,0 +1,17 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id;
diff --git a/synapse/storage/schema/delta/26/forgotten_memberships.sql b/synapse/storage/schema/delta/26/forgotten_memberships.sql
new file mode 100644
index 0000000000..df55b9c6f6
--- /dev/null
+++ b/synapse/storage/schema/delta/26/forgotten_memberships.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Keeps track of what rooms users have left and don't want to be able to
+ * access again.
+ *
+ * If all users on this server have left a room, we can delete the room
+ * entirely.
+ */
+
+ ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER(1) DEFAULT 0;
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bf695b7800..f6d826cc59 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -28,17 +28,17 @@ class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
- self._private_user_data_id_gen = StreamIdGenerator(
- "private_user_data_max_stream_id", "stream_id"
+ self._account_data_id_gen = StreamIdGenerator(
+ "account_data_max_stream_id", "stream_id"
)
- def get_max_private_user_data_stream_id(self):
+ def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
- return self._private_user_data_id_gen.get_max_token(self)
+ return self._account_data_id_gen.get_max_token(self)
@cached()
def get_tags_for_user(self, user_id):
@@ -144,12 +144,12 @@ class TagsStore(SQLBaseStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
+ with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._private_user_data_id_gen.get_max_token(self)
+ result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -166,12 +166,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
+ with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._private_user_data_id_gen.get_max_token(self)
+ result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
@@ -185,7 +185,7 @@ class TagsStore(SQLBaseStore):
"""
update_max_id_sql = (
- "UPDATE private_user_data_max_stream_id"
+ "UPDATE account_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index f0d68b5bf2..cfa7d30fa5 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -21,7 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
-from synapse.handlers.private_user_data import PrivateUserDataEventSource
+from synapse.handlers.account_data import AccountDataEventSource
class EventSources(object):
@@ -30,7 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource,
"typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource,
- "private_user_data": PrivateUserDataEventSource,
+ "account_data": AccountDataEventSource,
}
def __init__(self, hs):
@@ -54,8 +54,8 @@ class EventSources(object):
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
- private_user_data_key=(
- yield self.sources["private_user_data"].get_current_key()
+ account_data_key=(
+ yield self.sources["account_data"].get_current_key()
),
)
defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 28344d8b36..af1d76ab46 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -103,7 +103,7 @@ class StreamToken(
"presence_key",
"typing_key",
"receipt_key",
- "private_user_data_key",
+ "account_data_key",
))
):
_SEPARATOR = "_"
@@ -138,7 +138,7 @@ class StreamToken(
or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key))
- or (int(other.private_user_data_key) < int(self.private_user_data_key))
+ or (int(other.account_data_key) < int(self.account_data_key))
)
def copy_and_advance(self, key, new_value):
|