changeset 9699:e960bce2d584

8210870: Libsunmscapi improved interactions Reviewed-by: valeriep, mschoene, rhalade
author igerasim
date Thu, 31 Jan 2019 04:57:12 +0000
parents 0fdbcf2eb794
children 0d2ec8760dbc
files src/windows/classes/sun/security/mscapi/KeyStore.java src/windows/native/sun/security/mscapi/security.cpp
diffstat 2 files changed, 100 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/src/windows/classes/sun/security/mscapi/KeyStore.java	Thu Jan 31 04:54:52 2019 +0000
+++ b/src/windows/classes/sun/security/mscapi/KeyStore.java	Thu Jan 31 04:57:12 2019 +0000
@@ -757,6 +757,7 @@
     /**
      * Generates a certificate chain from the collection of
      * certificates and stores the result into a key entry.
+     * This method is called by native code in libsunmscapi.
      */
     private void generateCertificateChain(String alias,
         Collection<? extends Certificate> certCollection)
@@ -779,13 +780,15 @@
         catch (Throwable e)
         {
             // Ignore the exception and skip this entry
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
     }
 
     /**
      * Generates RSA key and certificate chain from the private key handle,
      * collection of certificates and stores the result into key entries.
+     * This method is called by native code in libsunmscapi.
      */
     private void generateRSAKeyAndCertificateChain(String alias,
         long hCryptProv, long hCryptKey, int keyLength,
@@ -810,12 +813,14 @@
         catch (Throwable e)
         {
             // Ignore the exception and skip this entry
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
     }
 
     /**
      * Generates certificates from byte data and stores into cert collection.
+     * This method is called by native code in libsunmscapi.
      *
      * @param data Byte data.
      * @param certCollection Collection of certificates.
@@ -839,12 +844,14 @@
         catch (CertificateException e)
         {
             // Ignore the exception and skip this certificate
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
         catch (Throwable te)
         {
             // Ignore the exception and skip this certificate
-            // TODO - throw CertificateException?
+            // If e is thrown, remember to deal with it in
+            // native code.
         }
     }
 
--- a/src/windows/native/sun/security/mscapi/security.cpp	Thu Jan 31 04:54:52 2019 +0000
+++ b/src/windows/native/sun/security/mscapi/security.cpp	Thu Jan 31 04:57:12 2019 +0000
@@ -414,6 +414,15 @@
                     // Create ArrayList to store certs in each chain
                     jobject jArrayList =
                         env->NewObject(clazzArrayList, mNewArrayList);
+                    if (jArrayList == NULL) {
+                        __leave;
+                    }
+
+                    // Cleanup the previous allocated name
+                    if (pszNameString) {
+                        delete [] pszNameString;
+                        pszNameString = NULL;
+                    }
 
                     for (unsigned int j=0; j < rgpChain->cElement; j++)
                     {
@@ -452,6 +461,9 @@
 
                         // Allocate and populate byte array
                         jbyteArray byteArray = env->NewByteArray(cbCertEncoded);
+                        if (byteArray == NULL) {
+                            __leave;
+                        }
                         env->SetByteArrayRegion(byteArray, 0, cbCertEncoded,
                             (jbyte*) pbCertEncoded);
 
@@ -459,30 +471,44 @@
                         // cert collection
                         env->CallVoidMethod(obj, mGenCert, byteArray, jArrayList);
                     }
-                    if (bHasNoPrivateKey)
-                    {
-                        // Generate certificate chain and store into cert chain
-                        // collection
-                        env->CallVoidMethod(obj, mGenCertChain,
-                            env->NewStringUTF(pszNameString),
-                            jArrayList);
-                    }
-                    else
+                    // Usually pszNameString should be non-NULL. It's either
+                    // the friendly name or an element from the subject name
+                    // or SAN.
+                    if (pszNameString)
                     {
-                        // Determine key type: RSA or DSA
-                        DWORD dwData = CALG_RSA_KEYX;
-                        DWORD dwSize = sizeof(DWORD);
-                        ::CryptGetKeyParam(hUserKey, KP_ALGID, (BYTE*)&dwData,
-                                &dwSize, NULL);
+                        if (bHasNoPrivateKey)
+                        {
+                            // Generate certificate chain and store into cert chain
+                            // collection
+                            jstring name = env->NewStringUTF(pszNameString);
+                            if (name == NULL) {
+                                __leave;
+                            }
+                            env->CallVoidMethod(obj, mGenCertChain,
+                                name,
+                                jArrayList);
+                        }
+                        else
+                        {
+                            // Determine key type: RSA or DSA
+                            DWORD dwData = CALG_RSA_KEYX;
+                            DWORD dwSize = sizeof(DWORD);
+                            ::CryptGetKeyParam(hUserKey, KP_ALGID, (BYTE*)&dwData,
+                                    &dwSize, NULL);
 
-                        if ((dwData & ALG_TYPE_RSA) == ALG_TYPE_RSA)
-                        {
-                            // Generate RSA certificate chain and store into cert
-                            // chain collection
-                            env->CallVoidMethod(obj, mGenRSAKeyAndCertChain,
-                                    env->NewStringUTF(pszNameString),
-                                    (jlong) hCryptProv, (jlong) hUserKey,
-                                    dwPublicKeyLength, jArrayList);
+                            if ((dwData & ALG_TYPE_RSA) == ALG_TYPE_RSA)
+                            {
+                                // Generate RSA certificate chain and store into cert
+                                // chain collection
+                                jstring name = env->NewStringUTF(pszNameString);
+                                if (name == NULL) {
+                                    __leave;
+                                }
+                                env->CallVoidMethod(obj, mGenRSAKeyAndCertChain,
+                                        name,
+                                        (jlong) hCryptProv, (jlong) hUserKey,
+                                        dwPublicKeyLength, jArrayList);
+                            }
                         }
                     }
                 }
@@ -629,6 +655,9 @@
 
         // Create new byte array
         jbyteArray temp = env->NewByteArray(dwBufLen);
+        if (temp == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer
         env->SetByteArrayRegion(temp, 0, dwBufLen, pSignedHashBuffer);
@@ -952,6 +981,9 @@
         }
 
         jCertAliasChars = env->GetStringChars(jCertAliasName, NULL);
+        if (jCertAliasChars == NULL) {
+            __leave;
+        }
         memcpy(pszCertAliasName, jCertAliasChars, size * sizeof(WCHAR));
         pszCertAliasName[size] = 0; // append the string terminator
 
@@ -1574,7 +1606,9 @@
         }
 
         // Create new byte array
-        result = env->NewByteArray(dwBufLen);
+        if ((result = env->NewByteArray(dwBufLen)) == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer to Java buffer
         env->SetByteArrayRegion(result, 0, dwBufLen, (jbyte*) pData);
@@ -1625,7 +1659,9 @@
         }
 
         // Create new byte array
-        blob = env->NewByteArray(dwBlobLen);
+        if ((blob = env->NewByteArray(dwBlobLen)) == NULL) {
+            __leave;
+        }
 
         // Copy data from native buffer to Java buffer
         env->SetByteArrayRegion(blob, 0, dwBlobLen, (jbyte*) pbKeyBlob);
@@ -1654,6 +1690,13 @@
     __try {
 
         jsize length = env->GetArrayLength(jKeyBlob);
+        jsize headerLength = sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY);
+
+        if (length < headerLength) {
+            ThrowExceptionWithMessage(env, KEY_EXCEPTION, "Invalid BLOB");
+            __leave;
+        }
+
         if ((keyBlob = env->GetByteArrayElements(jKeyBlob, 0)) == NULL) {
             __leave;
         }
@@ -1680,7 +1723,9 @@
             exponentBytes[i] = ((BYTE*) &pRsaPubKey->pubexp)[j];
         }
 
-        exponent = env->NewByteArray(len);
+        if ((exponent = env->NewByteArray(len)) == NULL) {
+            __leave;
+        }
         env->SetByteArrayRegion(exponent, 0, len, exponentBytes);
     }
     __finally
@@ -1710,6 +1755,13 @@
     __try {
 
         jsize length = env->GetArrayLength(jKeyBlob);
+        jsize headerLength = sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY);
+
+        if (length < headerLength) {
+            ThrowExceptionWithMessage(env, KEY_EXCEPTION, "Invalid BLOB");
+            __leave;
+        }
+
         if ((keyBlob = env->GetByteArrayElements(jKeyBlob, 0)) == NULL) {
             __leave;
         }
@@ -1726,19 +1778,25 @@
             (RSAPUBKEY *) (keyBlob + sizeof(PUBLICKEYSTRUC));
 
         int len = pRsaPubKey->bitlen / 8;
+        if (len < 0 || len > length - headerLength) {
+            ThrowExceptionWithMessage(env, KEY_EXCEPTION, "Invalid key length");
+            __leave;
+        }
+
         modulusBytes = new (env) jbyte[len];
         if (modulusBytes == NULL) {
             __leave;
         }
-        BYTE * pbModulus =
-            (BYTE *) (keyBlob + sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY));
+        BYTE * pbModulus = (BYTE *) (keyBlob + headerLength);
 
         // convert from little-endian while copying from blob
         for (int i = 0, j = len - 1; i < len; i++, j--) {
             modulusBytes[i] = pbModulus[j];
         }
 
-        modulus = env->NewByteArray(len);
+        if ((modulus = env->NewByteArray(len)) == NULL) {
+            __leave;
+        }
         env->SetByteArrayRegion(modulus, 0, len, modulusBytes);
     }
     __finally
@@ -1946,7 +2004,9 @@
             }
         }
 
-        jBlob = env->NewByteArray(jBlobLength);
+        if ((jBlob = env->NewByteArray(jBlobLength)) == NULL) {
+            __leave;
+        }
         env->SetByteArrayRegion(jBlob, 0, jBlobLength, jBlobBytes);
 
     }