diff options
Diffstat (limited to 'crypto/src/tls/Certificate.cs')
-rw-r--r-- | crypto/src/tls/Certificate.cs | 294 |
1 files changed, 294 insertions, 0 deletions
diff --git a/crypto/src/tls/Certificate.cs b/crypto/src/tls/Certificate.cs new file mode 100644 index 000000000..fef35fc1e --- /dev/null +++ b/crypto/src/tls/Certificate.cs @@ -0,0 +1,294 @@ +using System; +using System.Collections; +using System.IO; + +using Org.BouncyCastle.Tls.Crypto; +using Org.BouncyCastle.Utilities; + +namespace Org.BouncyCastle.Tls +{ + /// <summary>Parsing and encoding of a <i>Certificate</i> struct from RFC 4346.</summary> + /// <remarks> + /// <pre> + /// opaque ASN.1Cert<2^24-1>; + /// struct { + /// ASN.1Cert certificate_list<0..2^24-1>; + /// } Certificate; + /// </pre> + /// </remarks> + public sealed class Certificate + { + private static readonly TlsCertificate[] EmptyCerts = new TlsCertificate[0]; + private static readonly CertificateEntry[] EmptyCertEntries = new CertificateEntry[0]; + + public static readonly Certificate EmptyChain = new Certificate(EmptyCerts); + public static readonly Certificate EmptyChainTls13 = new Certificate(TlsUtilities.EmptyBytes, EmptyCertEntries); + + public sealed class ParseOptions + { + private int m_maxChainLength = int.MaxValue; + + public int MaxChainLength + { + get { return m_maxChainLength; } + } + + public ParseOptions SetMaxChainLength(int maxChainLength) + { + this.m_maxChainLength = maxChainLength; + return this; + } + } + + private static CertificateEntry[] Convert(TlsCertificate[] certificateList) + { + if (TlsUtilities.IsNullOrContainsNull(certificateList)) + throw new ArgumentException("cannot be null or contain any nulls", "certificateList"); + + int count = certificateList.Length; + CertificateEntry[] result = new CertificateEntry[count]; + for (int i = 0; i < count; ++i) + { + result[i] = new CertificateEntry(certificateList[i], null); + } + return result; + } + + private readonly byte[] m_certificateRequestContext; + private readonly CertificateEntry[] m_certificateEntryList; + + public Certificate(TlsCertificate[] certificateList) + : this(null, Convert(certificateList)) + { + } + + // TODO[tls13] Prefer to manage the certificateRequestContext internally only? + public Certificate(byte[] certificateRequestContext, CertificateEntry[] certificateEntryList) + { + if (null != certificateRequestContext && !TlsUtilities.IsValidUint8(certificateRequestContext.Length)) + throw new ArgumentException("cannot be longer than 255", "certificateRequestContext"); + if (TlsUtilities.IsNullOrContainsNull(certificateEntryList)) + throw new ArgumentException("cannot be null or contain any nulls", "certificateEntryList"); + + this.m_certificateRequestContext = TlsUtilities.Clone(certificateRequestContext); + this.m_certificateEntryList = certificateEntryList; + } + + public byte[] GetCertificateRequestContext() + { + return TlsUtilities.Clone(m_certificateRequestContext); + } + + /// <returns>an array of <see cref="TlsCertificate"/> representing a certificate chain.</returns> + public TlsCertificate[] GetCertificateList() + { + return CloneCertificateList(); + } + + public TlsCertificate GetCertificateAt(int index) + { + return m_certificateEntryList[index].Certificate; + } + + public CertificateEntry GetCertificateEntryAt(int index) + { + return m_certificateEntryList[index]; + } + + public CertificateEntry[] GetCertificateEntryList() + { + return CloneCertificateEntryList(); + } + + public short CertificateType + { + get { return Tls.CertificateType.X509; } + } + + public int Length + { + get { return m_certificateEntryList.Length; } + } + + /// <returns><c>true</c> if this certificate chain contains no certificates, or <c>false</c> otherwise. + /// </returns> + public bool IsEmpty + { + get { return m_certificateEntryList.Length == 0; } + } + + /// <summary>Encode this <see cref="Certificate"/> to a <see cref="Stream"/>, and optionally calculate the + /// "end point hash" (per RFC 5929's tls-server-end-point binding).</summary> + /// <param name="context">the <see cref="TlsContext"/> of the current connection.</param> + /// <param name="messageOutput">the <see cref="Stream"/> to encode to.</param> + /// <param name="endPointHashOutput">the <see cref="Stream"/> to write the "end point hash" to (or null). + /// </param> + /// <exception cref="IOException"/> + public void Encode(TlsContext context, Stream messageOutput, Stream endPointHashOutput) + { + bool isTlsV13 = TlsUtilities.IsTlsV13(context); + + if ((null != m_certificateRequestContext) != isTlsV13) + throw new InvalidOperationException(); + + if (isTlsV13) + { + TlsUtilities.WriteOpaque8(m_certificateRequestContext, messageOutput); + } + + int count = m_certificateEntryList.Length; + IList certEncodings = Platform.CreateArrayList(count); + IList extEncodings = isTlsV13 ? Platform.CreateArrayList(count) : null; + + long totalLength = 0; + for (int i = 0; i < count; ++i) + { + CertificateEntry entry = m_certificateEntryList[i]; + TlsCertificate cert = entry.Certificate; + byte[] derEncoding = cert.GetEncoded(); + + if (i == 0 && endPointHashOutput != null) + { + CalculateEndPointHash(context, cert, derEncoding, endPointHashOutput); + } + + certEncodings.Add(derEncoding); + totalLength += derEncoding.Length; + totalLength += 3; + + if (isTlsV13) + { + IDictionary extensions = entry.Extensions; + byte[] extEncoding = (null == extensions) + ? TlsUtilities.EmptyBytes + : TlsProtocol.WriteExtensionsData(extensions); + + extEncodings.Add(extEncoding); + totalLength += extEncoding.Length; + totalLength += 2; + } + } + + TlsUtilities.CheckUint24(totalLength); + TlsUtilities.WriteUint24((int)totalLength, messageOutput); + + for (int i = 0; i < count; ++i) + { + byte[] certEncoding = (byte[])certEncodings[i]; + TlsUtilities.WriteOpaque24(certEncoding, messageOutput); + + if (isTlsV13) + { + byte[] extEncoding = (byte[])extEncodings[i]; + TlsUtilities.WriteOpaque16(extEncoding, messageOutput); + } + } + } + + /// <summary>Parse a <see cref="Certificate"/> from a <see cref="Stream"/>.</summary> + /// <param name="options">the <see cref="ParseOptions"/> to apply during parsing.</param> + /// <param name="context">the <see cref="TlsContext"/> of the current connection.</param> + /// <param name="messageInput">the <see cref="Stream"/> to parse from.</param> + /// <param name="endPointHashOutput">the <see cref="Stream"/> to write the "end point hash" to (or null). + /// </param> + /// <returns>a <see cref="Certificate"/> object.</returns> + /// <exception cref="IOException"/> + public static Certificate Parse(ParseOptions options, TlsContext context, Stream messageInput, + Stream endPointHashOutput) + { + SecurityParameters securityParameters = context.SecurityParameters; + bool isTlsV13 = TlsUtilities.IsTlsV13(securityParameters.NegotiatedVersion); + + byte[] certificateRequestContext = null; + if (isTlsV13) + { + certificateRequestContext = TlsUtilities.ReadOpaque8(messageInput); + } + + int totalLength = TlsUtilities.ReadUint24(messageInput); + if (totalLength == 0) + { + return !isTlsV13 ? EmptyChain + : certificateRequestContext.Length < 1 ? EmptyChainTls13 + : new Certificate(certificateRequestContext, EmptyCertEntries); + } + + byte[] certListData = TlsUtilities.ReadFully(totalLength, messageInput); + MemoryStream buf = new MemoryStream(certListData, false); + + TlsCrypto crypto = context.Crypto; + int maxChainLength = System.Math.Max(1, options.MaxChainLength); + + IList certificate_list = Platform.CreateArrayList(); + while (buf.Position < buf.Length) + { + if (certificate_list.Count >= maxChainLength) + { + throw new TlsFatalAlert(AlertDescription.internal_error, + "Certificate chain longer than maximum (" + maxChainLength + ")"); + } + + byte[] derEncoding = TlsUtilities.ReadOpaque24(buf, 1); + TlsCertificate cert = crypto.CreateCertificate(derEncoding); + + if (certificate_list.Count < 1 && endPointHashOutput != null) + { + CalculateEndPointHash(context, cert, derEncoding, endPointHashOutput); + } + + IDictionary extensions = null; + if (isTlsV13) + { + byte[] extEncoding = TlsUtilities.ReadOpaque16(buf); + + extensions = TlsProtocol.ReadExtensionsData13(HandshakeType.certificate, extEncoding); + } + + certificate_list.Add(new CertificateEntry(cert, extensions)); + } + + CertificateEntry[] certificateList = new CertificateEntry[certificate_list.Count]; + for (int i = 0; i < certificate_list.Count; i++) + { + certificateList[i] = (CertificateEntry)certificate_list[i]; + } + + return new Certificate(certificateRequestContext, certificateList); + } + + private static void CalculateEndPointHash(TlsContext context, TlsCertificate cert, byte[] encoding, + Stream output) + { + byte[] endPointHash = TlsUtilities.CalculateEndPointHash(context, cert, encoding); + if (endPointHash != null && endPointHash.Length > 0) + { + output.Write(endPointHash, 0, endPointHash.Length); + } + } + + private TlsCertificate[] CloneCertificateList() + { + int count = m_certificateEntryList.Length; + if (0 == count) + return EmptyCerts; + + TlsCertificate[] result = new TlsCertificate[count]; + for (int i = 0; i < count; ++i) + { + result[i] = m_certificateEntryList[i].Certificate; + } + return result; + } + + private CertificateEntry[] CloneCertificateEntryList() + { + int count = m_certificateEntryList.Length; + if (0 == count) + return EmptyCertEntries; + + CertificateEntry[] result = new CertificateEntry[count]; + Array.Copy(m_certificateEntryList, 0, result, 0, count); + return result; + } + } +} |