summary refs log tree commit diff
path: root/crypto/src/asn1/DerInteger.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/asn1/DerInteger.cs')
-rw-r--r--crypto/src/asn1/DerInteger.cs137
1 files changed, 109 insertions, 28 deletions
diff --git a/crypto/src/asn1/DerInteger.cs b/crypto/src/asn1/DerInteger.cs
index ae14d2a9f..3e19a07b6 100644
--- a/crypto/src/asn1/DerInteger.cs
+++ b/crypto/src/asn1/DerInteger.cs
@@ -16,7 +16,11 @@ namespace Org.BouncyCastle.Asn1
             return allowUnsafeValue != null && Platform.EqualsIgnoreCase("true", allowUnsafeValue);
         }
 
+        internal const int SignExtSigned = -1;
+        internal const int SignExtUnsigned = 0xFF;
+
         private readonly byte[] bytes;
+        private readonly int start;
 
         /**
          * return an integer from the passed in object
@@ -60,42 +64,42 @@ namespace Org.BouncyCastle.Asn1
 			return new DerInteger(Asn1OctetString.GetInstance(o).GetOctets());
         }
 
-		public DerInteger(
-            int value)
+		public DerInteger(int value)
+        {
+            this.bytes = BigInteger.ValueOf(value).ToByteArray();
+            this.start = 0;
+        }
+
+        public DerInteger(long value)
         {
-            bytes = BigInteger.ValueOf(value).ToByteArray();
+            this.bytes = BigInteger.ValueOf(value).ToByteArray();
+            this.start = 0;
         }
 
-		public DerInteger(
-            BigInteger value)
+		public DerInteger(BigInteger value)
         {
             if (value == null)
                 throw new ArgumentNullException("value");
 
-			bytes = value.ToByteArray();
+			this.bytes = value.ToByteArray();
+            this.start = 0;
         }
 
-		public DerInteger(
-            byte[] bytes)
+        public DerInteger(byte[] bytes)
+            : this(bytes, true)
         {
-            if (bytes.Length > 1)
-            {
-                if ((bytes[0] == 0 && (bytes[1] & 0x80) == 0)
-                    || (bytes[0] == (byte)0xff && (bytes[1] & 0x80) != 0))
-                {
-                    if (!AllowUnsafe())
-                        throw new ArgumentException("malformed integer");
-                }
-            }
-            this.bytes = Arrays.Clone(bytes);
         }
 
-		public BigInteger Value
+        internal DerInteger(byte[] bytes, bool clone)
         {
-            get { return new BigInteger(bytes); }
+            if (IsMalformed(bytes))
+                throw new ArgumentException("malformed integer", "bytes");
+
+            this.bytes = clone ? Arrays.Clone(bytes) : bytes;
+            this.start = SignBytesToSkip(bytes);
         }
 
-		/**
+        /**
          * in some cases positive values Get crammed into a space,
          * that's not quite big enough...
          */
@@ -104,8 +108,44 @@ namespace Org.BouncyCastle.Asn1
             get { return new BigInteger(1, bytes); }
         }
 
-        internal override void Encode(
-            DerOutputStream derOut)
+        public BigInteger Value
+        {
+            get { return new BigInteger(bytes); }
+        }
+
+        public bool HasValue(BigInteger x)
+        {
+            return null != x
+                // Fast check to avoid allocation
+                && IntValue(bytes, start, SignExtSigned) == x.IntValue
+                && Value.Equals(x);
+        }
+
+        public int IntPositiveValueExact
+        {
+            get
+            {
+                int count = bytes.Length - start;
+                if (count > 4 || (count == 4 && 0 != (bytes[start] & 0x80)))
+                    throw new ArithmeticException("ASN.1 Integer out of positive int range");
+
+                return IntValue(bytes, start, SignExtUnsigned);
+            }
+        }
+
+        public int IntValueExact
+        {
+            get
+            {
+                int count = bytes.Length - start;
+                if (count > 4)
+                    throw new ArithmeticException("ASN.1 Integer out of int range");
+
+                return IntValue(bytes, start, SignExtSigned);
+            }
+        }
+
+        internal override void Encode(DerOutputStream derOut)
         {
             derOut.WriteEncoded(Asn1Tags.Integer, bytes);
         }
@@ -115,20 +155,61 @@ namespace Org.BouncyCastle.Asn1
 			return Arrays.GetHashCode(bytes);
         }
 
-		protected override bool Asn1Equals(
-			Asn1Object asn1Object)
+		protected override bool Asn1Equals(Asn1Object asn1Object)
 		{
 			DerInteger other = asn1Object as DerInteger;
-
 			if (other == null)
 				return false;
 
-			return Arrays.AreEqual(this.bytes, other.bytes);
+            return Arrays.AreEqual(this.bytes, other.bytes);
         }
 
 		public override string ToString()
 		{
 			return Value.ToString();
 		}
-	}
+
+        internal static int IntValue(byte[] bytes, int start, int signExt)
+        {
+            int length = bytes.Length;
+            int pos = System.Math.Max(start, length - 4);
+
+            int val = (sbyte)bytes[pos] & signExt;
+            while (++pos < length)
+            {
+                val = (val << 8) | bytes[pos];
+            }
+            return val;
+        }
+
+        /**
+         * Apply the correct validation for an INTEGER primitive following the BER rules.
+         *
+         * @param bytes The raw encoding of the integer.
+         * @return true if the (in)put fails this validation.
+         */
+        internal static bool IsMalformed(byte[] bytes)
+        {
+            switch (bytes.Length)
+            {
+            case 0:
+                return true;
+            case 1:
+                return false;
+            default:
+                return (sbyte)bytes[0] == ((sbyte)bytes[1] >> 7) && !AllowUnsafe();
+            }
+        }
+
+        internal static int SignBytesToSkip(byte[] bytes)
+        {
+            int pos = 0, last = bytes.Length - 1;
+            while (pos < last
+                && (sbyte)bytes[pos] == ((sbyte)bytes[pos + 1] >> 7))
+            {
+                ++pos;
+            }
+            return pos;
+        }
+    }
 }